(Arrays.asList(1, 3)), sets.get(5));
+ }
+
@SuppressWarnings("unchecked")
@Test
public void foldByKey() {
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 4e7c34e6d1ada..3aab88e9e9196 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark
import scala.collection.mutable
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.SparkContext._
-class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext {
implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] {
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index dc2db66df60e0..13b415cccb647 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
def newPairRDD = newRDD.map(_ -> 1)
def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
def newBroadcast = sc.broadcast(1 to 100)
- def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
@@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
- .filter(_.isInstanceOf[ShuffleDependency[_, _]])
- .map(_.asInstanceOf[ShuffleDependency[_, _]])
+ .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
(rdd, shuffleDeps)
}
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 14ddd6f1ec08f..41c294f727b3c 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts._
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.scalatest.time.{Millis, Span}
import org.apache.spark.SparkContext._
@@ -31,7 +31,7 @@ class NotSerializableClass
class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 2c8ef405c944c..a57430e829ced 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -25,7 +25,7 @@ import scala.concurrent.duration._
import scala.concurrent.future
import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.SparkContext._
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
@@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
* (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers
* in both FIFO and fair scheduling modes.
*/
-class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
with LocalSparkContext {
override def afterEach() {
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index be6508a40ea61..47112ce66d695 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.SparkContext._
import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
-class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val conf = new SparkConf(loadDefaults = false)
@@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 10)
@@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
@@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
index d6b93f5fedd3b..4161aede1d1d0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
@@ -18,9 +18,9 @@
package org.apache.spark.deploy
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
-class ClientSuite extends FunSuite with ShouldMatchers {
+class ClientSuite extends FunSuite with Matchers {
test("correctly validates driver jar URL's") {
ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true)
ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true)
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index bfae32dae0dc5..01ab2d549325c 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.FunSuite
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo}
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
+import org.apache.spark.SparkConf
class JsonProtocolSuite extends FunSuite {
@@ -116,7 +117,8 @@ class JsonProtocolSuite extends FunSuite {
}
def createExecutorRunner(): ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
- new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING)
+ new File("sparkHome"), new File("workDir"), "akka://worker",
+ new SparkConf, ExecutorState.RUNNING)
}
def createDriverRunner(): DriverRunner = {
new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(),
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 02427a4a83506..565c53e9529ff 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -25,9 +25,9 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException, Test
import org.apache.spark.deploy.SparkSubmit._
import org.apache.spark.util.Utils
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
-class SparkSubmitSuite extends FunSuite with ShouldMatchers {
+class SparkSubmitSuite extends FunSuite with Matchers {
def beforeAll() {
System.setProperty("spark.testing", "true")
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 8ae387fa0be6f..e5f748d55500d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.scalatest.FunSuite
import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
+import org.apache.spark.SparkConf
class ExecutorRunnerTest extends FunSuite {
test("command includes appId") {
@@ -32,7 +33,7 @@ class ExecutorRunnerTest extends FunSuite {
sparkHome, "appUiUrl")
val appId = "12345-worker321-9876"
val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")),
- f("ooga"), "blah", ExecutorState.RUNNING)
+ f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING)
assert(er.getCommandSeq.last === appId)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 9ddafc451878d..0b9004448a63e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
import org.apache.spark.{Partitioner, SharedSparkContext}
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("aggregateByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+ val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+ assert(sets.size === 3)
+ val valuesFor1 = sets.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1))
+ val valuesFor3 = sets.find(_._1 == 3).get._2
+ assert(valuesFor3.toList.sorted === List(2))
+ val valuesFor5 = sets.find(_._1 == 5).get._2
+ assert(valuesFor5.toList.sorted === List(1, 3))
+ }
+
test("groupByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 55af1666df662..e94a1e76d410c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
import org.apache.spark._
import org.apache.spark.SparkContext._
-import org.apache.spark.rdd._
+import org.apache.spark.util.Utils
class RDDSuite extends FunSuite with SharedSparkContext {
@@ -66,6 +66,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("serialization") {
+ val empty = new EmptyRDD[Int](sc)
+ val serial = Utils.serialize(empty)
+ val deserial: EmptyRDD[Int] = Utils.deserialize(serial)
+ assert(!deserial.toString().isEmpty())
+ }
+
test("countApproxDistinct") {
def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble
@@ -498,55 +505,56 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
test("takeSample") {
- val data = sc.parallelize(1 to 100, 2)
+ val n = 1000000
+ val data = sc.parallelize(1 to n, 2)
for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=false, 200, seed)
+ val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got only 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
{
- val sample = data.takeSample(withReplacement=true, num=100)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, num=n)
+ assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+ assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 100, seed)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, n, seed)
+ assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 200, seed)
- assert(sample.size === 200) // Got exactly 200 elements
+ val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+ assert(sample.size === 2 * n) // Got exactly 2 * n elements
// Chance of getting all distinct elements is still quite low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index d0619559bb457..656917628f7a8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.rdd
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.{Logging, SharedSparkContext}
import org.apache.spark.SparkContext._
-class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
+class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging {
test("sortByKey") {
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 5426e578a9ddd..be506e0287a16 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore
import scala.collection.mutable
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.executor.TaskMetrics
-class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
+class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
with BeforeAndAfter with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 81bd8257bc155..d7dbe5164b7f6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,7 +28,7 @@ import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
-import org.scalatest.matchers.ShouldMatchers._
+import org.scalatest.Matchers
import org.scalatest.time.SpanSugar._
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
@@ -39,7 +39,8 @@ import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, U
import scala.language.implicitConversions
import scala.language.postfixOps
-class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
+class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
+ with PrivateMethodTester {
private val conf = new SparkConf(false)
var store: BlockManager = null
var store2: BlockManager = null
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 8c06a2d9aa4ab..91b4c7b0dd962 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -18,14 +18,14 @@
package org.apache.spark.ui.jobs
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.{LocalSparkContext, SparkConf, Success}
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
-class JobProgressListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers {
test("test LRU eviction of stages") {
val conf = new SparkConf()
conf.set("spark.ui.retainedStages", 5.toString)
diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
index 63642461e4465..090d48ec921a1 100644
--- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.util
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
/**
*
*/
-class DistributionSuite extends FunSuite with ShouldMatchers {
+class DistributionSuite extends FunSuite with Matchers {
test("summary") {
val d = new Distribution((1 to 100).toArray.map{_.toDouble})
val stats = d.statCounter
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
new file mode 100644
index 0000000000000..53d7f5c6072e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io._
+
+import scala.collection.mutable.HashSet
+import scala.reflect._
+
+import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.spark.{Logging, SparkConf}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender}
+
+class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
+
+ val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile
+
+ before {
+ cleanup()
+ }
+
+ after {
+ cleanup()
+ }
+
+ test("basic file appender") {
+ val testString = (1 to 1000).mkString(", ")
+ val inputStream = IOUtils.toInputStream(testString)
+ val appender = new FileAppender(inputStream, testFile)
+ inputStream.close()
+ appender.awaitTermination()
+ assert(FileUtils.readFileToString(testFile) === testString)
+ }
+
+ test("rolling file appender - time-based rolling") {
+ // setup input stream and appender
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000)
+ val rolloverIntervalMillis = 100
+ val durationMillis = 1000
+ val numRollovers = durationMillis / rolloverIntervalMillis
+ val textToAppend = (1 to numRollovers).map( _.toString * 10 )
+
+ val appender = new RollingFileAppender(testInputStream, testFile,
+ new TimeBasedRollingPolicy(rolloverIntervalMillis, s"--HH-mm-ss-SSSS", false),
+ new SparkConf(), 10)
+
+ testRolling(appender, testOutputStream, textToAppend, rolloverIntervalMillis)
+ }
+
+ test("rolling file appender - size-based rolling") {
+ // setup input stream and appender
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000)
+ val rolloverSize = 1000
+ val textToAppend = (1 to 3).map( _.toString * 1000 )
+
+ val appender = new RollingFileAppender(testInputStream, testFile,
+ new SizeBasedRollingPolicy(rolloverSize, false), new SparkConf(), 99)
+
+ val files = testRolling(appender, testOutputStream, textToAppend, 0)
+ files.foreach { file =>
+ logInfo(file.toString + ": " + file.length + " bytes")
+ assert(file.length <= rolloverSize)
+ }
+ }
+
+ test("rolling file appender - cleaning") {
+ // setup input stream and appender
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000)
+ val conf = new SparkConf().set(RollingFileAppender.RETAINED_FILES_PROPERTY, "10")
+ val appender = new RollingFileAppender(testInputStream, testFile,
+ new SizeBasedRollingPolicy(1000, false), conf, 10)
+
+ // send data to appender through the input stream, and wait for the data to be written
+ val allGeneratedFiles = new HashSet[String]()
+ val items = (1 to 10).map { _.toString * 10000 }
+ for (i <- 0 until items.size) {
+ testOutputStream.write(items(i).getBytes("UTF8"))
+ testOutputStream.flush()
+ allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles(
+ testFile.getParentFile.toString, testFile.getName).map(_.toString)
+
+ Thread.sleep(10)
+ }
+ testOutputStream.close()
+ appender.awaitTermination()
+ logInfo("Appender closed")
+
+ // verify whether the earliest file has been deleted
+ val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted
+ logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n"))
+ assert(rolledOverFiles.size > 2)
+ val earliestRolledOverFile = rolledOverFiles.head
+ val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles(
+ testFile.getParentFile.toString, testFile.getName).map(_.toString)
+ logInfo("Existing rolled over files:\n" + existingRolledOverFiles.mkString("\n"))
+ assert(!existingRolledOverFiles.toSet.contains(earliestRolledOverFile))
+ }
+
+ test("file appender selection") {
+ // Test whether FileAppender.apply() returns the right type of the FileAppender based
+ // on SparkConf settings.
+
+ def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy](
+ properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = {
+
+ // Set spark conf properties
+ val conf = new SparkConf
+ properties.foreach { p =>
+ conf.set(p._1, p._2)
+ }
+
+ // Create and test file appender
+ val inputStream = new PipedInputStream(new PipedOutputStream())
+ val appender = FileAppender(inputStream, new File("stdout"), conf)
+ assert(appender.isInstanceOf[ExpectedAppender])
+ assert(appender.getClass.getSimpleName ===
+ classTag[ExpectedAppender].runtimeClass.getSimpleName)
+ if (appender.isInstanceOf[RollingFileAppender]) {
+ val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy
+ rollingPolicy.isInstanceOf[ExpectedRollingPolicy]
+ val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) {
+ rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis
+ } else {
+ rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes
+ }
+ assert(policyParam === expectedRollingPolicyParam)
+ }
+ appender
+ }
+
+ import RollingFileAppender._
+
+ def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy)
+ def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size)
+ def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval)
+
+ val msInDay = 24 * 60 * 60 * 1000L
+ val msInHour = 60 * 60 * 1000L
+ val msInMinute = 60 * 1000L
+
+ // test no strategy -> no rolling
+ testAppenderSelection[FileAppender, Any](Seq.empty)
+
+ // test time based rolling strategy
+ testAppenderSelection[RollingFileAppender, Any](rollingStrategy("time"), msInDay)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("daily"), msInDay)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("hourly"), msInHour)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("minutely"), msInMinute)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("123456789"), 123456789 * 1000L)
+ testAppenderSelection[FileAppender, Any](
+ rollingStrategy("time") ++ rollingInterval("xyz"))
+
+ // test size based rolling strategy
+ testAppenderSelection[RollingFileAppender, SizeBasedRollingPolicy](
+ rollingStrategy("size") ++ rollingSize("123456789"), 123456789)
+ testAppenderSelection[FileAppender, Any](rollingSize("xyz"))
+
+ // test illegal strategy
+ testAppenderSelection[FileAppender, Any](rollingStrategy("xyz"))
+ }
+
+ /**
+ * Run the rolling file appender with data and see whether all the data was written correctly
+ * across rolled over files.
+ */
+ def testRolling(
+ appender: FileAppender,
+ outputStream: OutputStream,
+ textToAppend: Seq[String],
+ sleepTimeBetweenTexts: Long
+ ): Seq[File] = {
+ // send data to appender through the input stream, and wait for the data to be written
+ val expectedText = textToAppend.mkString("")
+ for (i <- 0 until textToAppend.size) {
+ outputStream.write(textToAppend(i).getBytes("UTF8"))
+ outputStream.flush()
+ Thread.sleep(sleepTimeBetweenTexts)
+ }
+ logInfo("Data sent to appender")
+ outputStream.close()
+ appender.awaitTermination()
+ logInfo("Appender closed")
+
+ // verify whether all the data written to rolled over files is same as expected
+ val generatedFiles = RollingFileAppender.getSortedRolledOverFiles(
+ testFile.getParentFile.toString, testFile.getName)
+ logInfo("Filtered files: \n" + generatedFiles.mkString("\n"))
+ assert(generatedFiles.size > 1)
+ val allText = generatedFiles.map { file =>
+ FileUtils.readFileToString(file)
+ }.mkString("")
+ assert(allText === expectedText)
+ generatedFiles
+ }
+
+ /** Delete all the generated rolledover files */
+ def cleanup() {
+ testFile.getParentFile.listFiles.filter { file =>
+ file.getName.startsWith(testFile.getName)
+ }.foreach { _.delete() }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
index 32d74d0500b72..cf438a3d72a06 100644
--- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
@@ -22,9 +22,9 @@ import java.util.NoSuchElementException
import scala.collection.mutable.Buffer
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
-class NextIteratorSuite extends FunSuite with ShouldMatchers {
+class NextIteratorSuite extends FunSuite with Matchers {
test("one iteration") {
val i = new StubIterator(Buffer(1))
i.hasNext should be === true
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 0aad882ed76a8..1ee936bc78f49 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -140,6 +140,38 @@ class UtilsSuite extends FunSuite {
Utils.deleteRecursively(tmpDir2)
}
+ test("reading offset bytes across multiple files") {
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val files = (1 to 3).map(i => new File(tmpDir, i.toString))
+ Files.write("0123456789", files(0), Charsets.UTF_8)
+ Files.write("abcdefghij", files(1), Charsets.UTF_8)
+ Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8)
+
+ // Read first few bytes in the 1st file
+ assert(Utils.offsetBytes(files, 0, 5) === "01234")
+
+ // Read bytes within the 1st file
+ assert(Utils.offsetBytes(files, 5, 8) === "567")
+
+ // Read bytes across 1st and 2nd file
+ assert(Utils.offsetBytes(files, 8, 18) === "89abcdefgh")
+
+ // Read bytes across 1st, 2nd and 3rd file
+ assert(Utils.offsetBytes(files, 5, 24) === "56789abcdefghijABCD")
+
+ // Read some nonexistent bytes in the beginning
+ assert(Utils.offsetBytes(files, -5, 18) === "0123456789abcdefgh")
+
+ // Read some nonexistent bytes at the end
+ assert(Utils.offsetBytes(files, 18, 35) === "ijABCDEFGHIJ")
+
+ // Read some nonexistent bytes on both ends
+ assert(Utils.offsetBytes(files, -5, 35) === "0123456789abcdefghijABCDEFGHIJ")
+
+ Utils.deleteRecursively(tmpDir)
+ }
+
test("deserialize long value") {
val testval : Long = 9730889947L
val bbuf = ByteBuffer.allocate(8)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index b024c89d94d33..6a70877356409 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.SizeEstimator
-class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+class OpenHashMapSuite extends FunSuite with Matchers {
test("size for specialized, primitive value (int)") {
val capacity = 1024
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index ff4a98f5dcd4a..68a03e3a0970f 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -18,11 +18,11 @@
package org.apache.spark.util.collection
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.SizeEstimator
-class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+class OpenHashSetSuite extends FunSuite with Matchers {
test("size for specialized, primitive int") {
val loadFactor = 0.7
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index e3fca173908e9..8c7df7d73dcd3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.SizeEstimator
-class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers {
test("size for specialized, primitive key, value (int, int)") {
val capacity = 1024
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
new file mode 100644
index 0000000000000..accfe2e9b7f2a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
+import org.scalatest.FunSuite
+
+class SamplingUtilsSuite extends FunSuite {
+
+ test("computeFraction") {
+ // test that the computed fraction guarantees enough data points
+ // in the sample with a failure rate <= 0.0001
+ val n = 100000
+
+ for (s <- 1 to 15) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+ val poisson = new PoissonDistribution(frac * n)
+ assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(20, 100, 1000)) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, true)
+ val poisson = new PoissonDistribution(frac * n)
+ assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(1, 10, 100, 1000)) {
+ val frac = SamplingUtils.computeFractionForSampleSize(s, n, false)
+ val binomial = new BinomialDistribution(n, frac)
+ assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low")
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
index 0865c6386f7cd..e15fd59a5a8bb 100644
--- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.util.random
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.Utils.times
import scala.language.reflectiveCalls
-class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+class XORShiftRandomSuite extends FunSuite with Matchers {
def fixture = new {
val seed = 1L
diff --git a/dev/mima b/dev/mima
index ab6bd4469b0e8..b68800d6d0173 100755
--- a/dev/mima
+++ b/dev/mima
@@ -23,6 +23,9 @@ set -o pipefail
FWDIR="$(cd `dirname $0`/..; pwd)"
cd $FWDIR
+echo -e "q\n" | sbt/sbt oldDeps/update
+
+export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" `
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving"
ret_val=$?
@@ -31,5 +34,5 @@ if [ $ret_val != 0 ]; then
echo "NOTE: Exceptions to binary compatibility can be added in project/MimaExcludes.scala"
fi
-rm -f .generated-mima-excludes
+rm -f .generated-mima*
exit $ret_val
diff --git a/docs/configuration.md b/docs/configuration.md
index 71fafa573467f..b84104cc7e653 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -784,6 +784,45 @@ Apart from these, the following properties are also available, and may be useful
higher memory usage in Spark.
+<tr>
+  <td><code>spark.executor.logs.rolling.strategy</code></td>
+  <td>(none)</td>
+  <td>
+    Set the strategy of rolling of executor logs. By default it is disabled. It can
+    be set to "time" (time-based rolling) or "size" (size-based rolling). For "time",
+    use <code>spark.executor.logs.rolling.time.interval</code> to set the rolling interval.
+    For "size", use <code>spark.executor.logs.rolling.size.maxBytes</code> to set
+    the maximum file size for rolling.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.executor.logs.rolling.time.interval</code></td>
+  <td>daily</td>
+  <td>
+    Set the time interval by which the executor logs will be rolled over.
+    Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or
+    any interval in seconds. See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
+    for automatic cleaning of old logs.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.executor.logs.rolling.size.maxBytes</code></td>
+  <td>(none)</td>
+  <td>
+    Set the max size of the file by which the executor logs will be rolled over.
+    Rolling is disabled by default. Value is set in terms of bytes.
+    See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
+    for automatic cleaning of old logs.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.executor.logs.rolling.maxRetainedFiles</code></td>
+  <td>(none)</td>
+  <td>
+    Sets the number of latest rolling log files that are going to be retained by the system.
+    Older log files will be deleted. Disabled by default.
+  </td>
+</tr>
#### Cluster Managers
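
Note: the four properties documented above are ordinary Spark configuration entries, so they can be set like any other conf key. A minimal sketch (values are illustrative, not defaults) of enabling size-based rolling with cleanup of old logs, applied to the SparkConf used to launch the application:

    import org.apache.spark.SparkConf

    // Roll each executor's log file once it reaches ~10 MB and keep only
    // the 5 most recent rolled-over files (illustrative values).
    val conf = new SparkConf()
      .set("spark.executor.logs.rolling.strategy", "size")
      .set("spark.executor.logs.rolling.size.maxBytes", (10 * 1024 * 1024).toString)
      .set("spark.executor.logs.rolling.maxRetainedFiles", "5")
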
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 7989e02dfb732..79784682bfd1b 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -890,6 +890,10 @@ for details.
<tr>
  <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
</tr>
+<tr>
+  <td> <b>aggregateByKey</b>(<i>zeroValue</i>)(<i>seqOp</i>, <i>combOp</i>, [<i>numTasks</i>]) </td>
+  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
<tr>
  <td> <b>sortByKey</b>([<i>ascending</i>], [<i>numTasks</i>]) </td>
  <td> When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean <code>ascending</code> argument. </td>
</tr>
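
A quick sketch of the aggregateByKey signature described above, mirroring the PairRDDFunctionsSuite test added earlier in this patch (assumes an active SparkContext named sc):

    import scala.collection.mutable.HashSet
    import org.apache.spark.SparkContext._  // brings pair-RDD functions into scope

    // Aggregate the Int values for each key into a HashSet: the sequence op adds a
    // value to a partition-local set, the combine op merges sets across partitions.
    val pairs = sc.parallelize(Seq((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
    val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
    // e.g. (1, Set(1)), (3, Set(2)), (5, Set(1, 3))
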
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 6eb41e7ba36fb..28e201d279f41 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -50,6 +50,8 @@ object MovieLensALS {
numIterations: Int = 20,
lambda: Double = 1.0,
rank: Int = 10,
+ numUserBlocks: Int = -1,
+ numProductBlocks: Int = -1,
implicitPrefs: Boolean = false)
def main(args: Array[String]) {
@@ -67,8 +69,14 @@ object MovieLensALS {
.text(s"lambda (smoothing constant), default: ${defaultParams.lambda}")
.action((x, c) => c.copy(lambda = x))
opt[Unit]("kryo")
- .text(s"use Kryo serialization")
+ .text("use Kryo serialization")
.action((_, c) => c.copy(kryo = true))
+ opt[Int]("numUserBlocks")
+ .text(s"number of user blocks, default: ${defaultParams.numUserBlocks} (auto)")
+ .action((x, c) => c.copy(numUserBlocks = x))
+ opt[Int]("numProductBlocks")
+ .text(s"number of product blocks, default: ${defaultParams.numProductBlocks} (auto)")
+ .action((x, c) => c.copy(numProductBlocks = x))
opt[Unit]("implicitPrefs")
.text("use implicit preference")
.action((_, c) => c.copy(implicitPrefs = true))
@@ -160,6 +168,8 @@ object MovieLensALS {
.setIterations(params.numIterations)
.setLambda(params.lambda)
.setImplicitPrefs(params.implicitPrefs)
+ .setUserBlocks(params.numUserBlocks)
+ .setProductBlocks(params.numProductBlocks)
.run(training)
val rmse = computeRmse(model, test, params.implicitPrefs)
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 5be33f1d5c428..ed35e34ad45ab 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -71,12 +71,12 @@ class SparkFlumeEvent() extends Externalizable {
for (i <- 0 until numHeaders) {
val keyLength = in.readInt()
val keyBuff = new Array[Byte](keyLength)
- in.read(keyBuff)
+ in.readFully(keyBuff)
val key : String = Utils.deserialize(keyBuff)
val valLength = in.readInt()
val valBuff = new Array[Byte](valLength)
- in.read(valBuff)
+ in.readFully(valBuff)
val value : String = Utils.deserialize(valBuff)
headers.put(key, value)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index d743bd7dd1825..cc56fd6ef28d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -61,7 +61,7 @@ private[recommendation] case class InLinkBlock(
* A more compact class to represent a rating than Tuple3[Int, Int, Double].
*/
@Experimental
-case class Rating(val user: Int, val product: Int, val rating: Double)
+case class Rating(user: Int, product: Int, rating: Double)
/**
* Alternating Least Squares matrix factorization.
@@ -93,7 +93,8 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* preferences rather than explicit ratings given to items.
*/
class ALS private (
- private var numBlocks: Int,
+ private var numUserBlocks: Int,
+ private var numProductBlocks: Int,
private var rank: Int,
private var iterations: Int,
private var lambda: Double,
@@ -106,14 +107,31 @@ class ALS private (
* Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10,
* lambda: 0.01, implicitPrefs: false, alpha: 1.0}.
*/
- def this() = this(-1, 10, 10, 0.01, false, 1.0)
+ def this() = this(-1, -1, 10, 10, 0.01, false, 1.0)
/**
- * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured
- * number of blocks. Default: -1.
+ * Set the number of blocks for both user blocks and product blocks to parallelize the computation
+ * into; pass -1 for an auto-configured number of blocks. Default: -1.
*/
def setBlocks(numBlocks: Int): ALS = {
- this.numBlocks = numBlocks
+ this.numUserBlocks = numBlocks
+ this.numProductBlocks = numBlocks
+ this
+ }
+
+ /**
+ * Set the number of user blocks to parallelize the computation.
+ */
+ def setUserBlocks(numUserBlocks: Int): ALS = {
+ this.numUserBlocks = numUserBlocks
+ this
+ }
+
+ /**
+ * Set the number of product blocks to parallelize the computation.
+ */
+ def setProductBlocks(numProductBlocks: Int): ALS = {
+ this.numProductBlocks = numProductBlocks
this
}
@@ -176,31 +194,32 @@ class ALS private (
def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
val sc = ratings.context
- val numBlocks = if (this.numBlocks == -1) {
+ val numUserBlocks = if (this.numUserBlocks == -1) {
math.max(sc.defaultParallelism, ratings.partitions.size / 2)
} else {
- this.numBlocks
+ this.numUserBlocks
}
-
- val partitioner = new Partitioner {
- val numPartitions = numBlocks
-
- def getPartition(x: Any): Int = {
- Utils.nonNegativeMod(byteswap32(x.asInstanceOf[Int]), numPartitions)
- }
+ val numProductBlocks = if (this.numProductBlocks == -1) {
+ math.max(sc.defaultParallelism, ratings.partitions.size / 2)
+ } else {
+ this.numProductBlocks
}
- val ratingsByUserBlock = ratings.map{ rating =>
- (partitioner.getPartition(rating.user), rating)
+ val userPartitioner = new ALSPartitioner(numUserBlocks)
+ val productPartitioner = new ALSPartitioner(numProductBlocks)
+
+ val ratingsByUserBlock = ratings.map { rating =>
+ (userPartitioner.getPartition(rating.user), rating)
}
- val ratingsByProductBlock = ratings.map{ rating =>
- (partitioner.getPartition(rating.product),
+ val ratingsByProductBlock = ratings.map { rating =>
+ (productPartitioner.getPartition(rating.product),
Rating(rating.product, rating.user, rating.rating))
}
- val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock, partitioner)
+ val (userInLinks, userOutLinks) =
+ makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner)
val (productInLinks, productOutLinks) =
- makeLinkRDDs(numBlocks, ratingsByProductBlock, partitioner)
+ makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner)
userInLinks.setName("userInLinks")
userOutLinks.setName("userOutLinks")
productInLinks.setName("productInLinks")
@@ -232,27 +251,27 @@ class ALS private (
users.setName(s"users-$iter").persist()
val YtY = Some(sc.broadcast(computeYtY(users)))
val previousProducts = products
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
- alpha, YtY)
+ products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks,
+ userPartitioner, rank, lambda, alpha, YtY)
previousProducts.unpersist()
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
products.setName(s"products-$iter").persist()
val XtX = Some(sc.broadcast(computeYtY(products)))
val previousUsers = users
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
- alpha, XtX)
+ users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks,
+ productPartitioner, rank, lambda, alpha, XtX)
previousUsers.unpersist()
}
} else {
for (iter <- 1 to iterations) {
// perform ALS update
logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
- alpha, YtY = None)
+ products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks,
+ userPartitioner, rank, lambda, alpha, YtY = None)
products.setName(s"products-$iter")
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
- alpha, YtY = None)
+ users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks,
+ productPartitioner, rank, lambda, alpha, YtY = None)
users.setName(s"users-$iter")
}
}
@@ -340,9 +359,10 @@ class ALS private (
/**
* Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
*/
- private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
- outLinks: RDD[(Int, OutLinkBlock)]) = {
- blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
+ private def unblockFactors(
+ blockedFactors: RDD[(Int, Array[Array[Double]])],
+ outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = {
+ blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) =>
for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
}
}
@@ -351,14 +371,14 @@ class ALS private (
* Make the out-links table for a block of the users (or products) dataset given the list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
- private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating],
- partitioner: Partitioner): OutLinkBlock = {
+ private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating],
+ productPartitioner: Partitioner): OutLinkBlock = {
val userIds = ratings.map(_.user).distinct.sorted
val numUsers = userIds.length
val userIdToPos = userIds.zipWithIndex.toMap
- val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks))
+ val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks))
for (r <- ratings) {
- shouldSend(userIdToPos(r.user))(partitioner.getPartition(r.product)) = true
+ shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true
}
OutLinkBlock(userIds, shouldSend)
}
@@ -367,18 +387,17 @@ class ALS private (
* Make the in-links table for a block of the users (or products) dataset given a list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
- private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating],
- partitioner: Partitioner): InLinkBlock = {
+ private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating],
+ productPartitioner: Partitioner): InLinkBlock = {
val userIds = ratings.map(_.user).distinct.sorted
- val numUsers = userIds.length
val userIdToPos = userIds.zipWithIndex.toMap
// Split out our ratings by product block
- val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating])
+ val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating])
for (r <- ratings) {
- blockRatings(partitioner.getPartition(r.product)) += r
+ blockRatings(productPartitioner.getPartition(r.product)) += r
}
- val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks)
- for (productBlock <- 0 until numBlocks) {
+ val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks)
+ for (productBlock <- 0 until numProductBlocks) {
// Create an array of (product, Seq(Rating)) ratings
val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray
// Sort them by product ID
@@ -400,14 +419,16 @@ class ALS private (
* the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
* having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
*/
- private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)], partitioner: Partitioner)
- : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
- {
- val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
+ private def makeLinkRDDs(
+ numUserBlocks: Int,
+ numProductBlocks: Int,
+ ratingsByUserBlock: RDD[(Int, Rating)],
+ productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = {
+ val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks))
val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
- val ratings = elements.map{_._2}.toArray
- val inLinkBlock = makeInLinkBlock(numBlocks, ratings, partitioner)
- val outLinkBlock = makeOutLinkBlock(numBlocks, ratings, partitioner)
+ val ratings = elements.map(_._2).toArray
+ val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner)
+ val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner)
Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
}, true)
val inLinks = links.mapValues(_._1)
@@ -439,26 +460,24 @@ class ALS private (
* It returns an RDD of new feature vectors for each user block.
*/
private def updateFeatures(
+ numUserBlocks: Int,
products: RDD[(Int, Array[Array[Double]])],
productOutLinks: RDD[(Int, OutLinkBlock)],
userInLinks: RDD[(Int, InLinkBlock)],
- partitioner: Partitioner,
+ productPartitioner: Partitioner,
rank: Int,
lambda: Double,
alpha: Double,
- YtY: Option[Broadcast[DoubleMatrix]])
- : RDD[(Int, Array[Array[Double]])] =
- {
- val numBlocks = products.partitions.size
+ YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = {
productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) =>
- val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]])
- for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) {
+ val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]])
+ for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) {
if (outLinkBlock.shouldSend(p)(userBlock)) {
toSend(userBlock) += factors(p)
}
}
toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
- }.groupByKey(partitioner)
+ }.groupByKey(productPartitioner)
.join(userInLinks)
.mapValues{ case (messages, inLinkBlock) =>
updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY)
@@ -475,7 +494,7 @@ class ALS private (
{
// Sort the incoming block factor messages by block ID and make them an array
val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]]
- val numBlocks = blockFactors.length
+ val numProductBlocks = blockFactors.length
val numUsers = inLinkBlock.elementIds.length
// We'll sum up the XtXes using vectors that represent only the lower-triangular part, since
@@ -488,9 +507,12 @@ class ALS private (
val tempXtX = DoubleMatrix.zeros(triangleSize)
val fullXtX = DoubleMatrix.zeros(rank, rank)
+ // Count the number of ratings each user gives to provide user-specific regularization
+ val numRatings = Array.fill(numUsers)(0)
+
// Compute the XtX and Xy values for each user by adding products it rated in each product
// block
- for (productBlock <- 0 until numBlocks) {
+ for (productBlock <- 0 until numProductBlocks) {
var p = 0
while (p < blockFactors(productBlock).length) {
val x = wrapDoubleArray(blockFactors(productBlock)(p))
@@ -500,6 +522,7 @@ class ALS private (
if (implicitPrefs) {
var i = 0
while (i < us.length) {
+ numRatings(us(i)) += 1
// Extension to the original paper to handle rs(i) < 0. confidence is a function
// of |rs(i)| instead so that it is never negative:
val confidence = 1 + alpha * abs(rs(i))
@@ -515,6 +538,7 @@ class ALS private (
} else {
var i = 0
while (i < us.length) {
+ numRatings(us(i)) += 1
userXtX(us(i)).addi(tempXtX)
SimpleBlas.axpy(rs(i), x, userXy(us(i)))
i += 1
@@ -531,9 +555,10 @@ class ALS private (
// Compute the full XtX matrix from the lower-triangular part we got above
fillFullMatrix(userXtX(index), fullXtX)
// Add regularization
+ val regParam = numRatings(index) * lambda
var i = 0
while (i < rank) {
- fullXtX.data(i * rank + i) += lambda
+ fullXtX.data(i * rank + i) += regParam
i += 1
}
// Solve the resulting matrix, which is symmetric and positive-definite
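A note on the regularization change above: instead of adding a flat lambda to the diagonal of XtX, each user now gets numRatings * lambda, so users with many ratings are regularized more strongly. A minimal numpy sketch of that solve step, on hypothetical toy data rather than the jblas buffers used in the diff:

    import numpy as np

    rank, lam = 2, 0.01
    # hypothetical toy data: factors of the products one user rated, and the ratings
    Y = np.array([[0.3, 1.2], [0.8, 0.1], [1.1, 0.7]])   # 3 rated products x rank
    r = np.array([4.0, 2.0, 5.0])
    num_ratings = Y.shape[0]                              # this user's rating count

    XtX = Y.T @ Y                                         # sum of x_i * x_i^T over rated products
    Xy = Y.T @ r                                          # sum of r_i * x_i
    XtX += num_ratings * lam * np.eye(rank)               # regParam = numRatings * lambda on the diagonal
    print(np.linalg.solve(XtX, Xy))                       # the user's new factor vector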
@@ -579,6 +604,23 @@ class ALS private (
}
}
+/**
+ * Partitioner for ALS.
+ */
+private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner {
+ override def getPartition(key: Any): Int = {
+ Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions)
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case p: ALSPartitioner =>
+ this.numPartitions == p.numPartitions
+ case _ =>
+ false
+ }
+ }
+}
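The partitioner above scrambles the integer key with byteswap32 before taking a non-negative modulo, so consecutive user or product ids do not all land in the same block. A rough Python sketch of the mapping; the exact definition of byteswap32 (byte reversal of v * 0x9e3779b1 in 32-bit arithmetic) is an assumption about the Scala library helper, not something shown in this diff:

    def byteswap32(v):
        # assumed behaviour of scala.util.hashing.byteswap32: multiply by the
        # constant 0x9e3779b1, then reverse the four bytes (32-bit arithmetic)
        x = (v * 0x9e3779b1) & 0xFFFFFFFF
        x = int.from_bytes(x.to_bytes(4, "little"), "big")
        return x - (1 << 32) if x >= (1 << 31) else x     # reinterpret as a signed Int

    def non_negative_mod(x, mod):
        # Utils.nonNegativeMod: a modulo that never yields a negative partition id
        return (x % mod + mod) % mod

    num_partitions = 4
    for key in (0, 1, 2, 12345):
        print(key, "->", non_negative_mod(byteswap32(key), num_partitions))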
/**
* Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
@@ -606,7 +648,7 @@ object ALS {
blocks: Int,
seed: Long
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings)
}
/**
@@ -629,7 +671,7 @@ object ALS {
lambda: Double,
blocks: Int
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0).run(ratings)
}
/**
@@ -689,7 +731,7 @@ object ALS {
alpha: Double,
seed: Long
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
}
/**
@@ -714,7 +756,7 @@ object ALS {
blocks: Int,
alpha: Double
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, true, alpha).run(ratings)
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 4d7b984e3ec29..44b757b6a1fb7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import scala.collection.JavaConversions._
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
@@ -56,7 +56,7 @@ object LogisticRegressionSuite {
}
}
-class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 8a16284118cf7..951b4f7c6e6f4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import scala.collection.JavaConversions._
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
@@ -61,7 +61,7 @@ object GradientDescentSuite {
}
}
-class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers {
test("Assert the loss is decreasing.") {
val nPoints = 10000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 820eca9b1bf65..4b1850659a18e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.mllib.optimization
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
-class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
val nPoints = 10000
val A = 2.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 37c9b9d085841..81bebec8c7a39 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -121,6 +121,10 @@ class ALSSuite extends FunSuite with LocalSparkContext {
testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true)
}
+ test("rank-2 matrices with different user and product blocks") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, numUserBlocks = 4, numProductBlocks = 2)
+ }
+
test("pseudorandomness") {
val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2)
val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1)
@@ -153,35 +157,52 @@ class ALSSuite extends FunSuite with LocalSparkContext {
}
test("NNALS, rank 2") {
- testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, false)
+ testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false)
}
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
- * @param users number of users
- * @param products number of products
- * @param features number of features (rank of problem)
- * @param iterations number of iterations to run
- * @param samplingRate what fraction of the user-product pairs are known
+ * @param users number of users
+ * @param products number of products
+ * @param features number of features (rank of problem)
+ * @param iterations number of iterations to run
+ * @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
- * @param implicitPrefs flag to test implicit feedback
- * @param bulkPredict flag to test bulk prediciton
+ * @param implicitPrefs flag to test implicit feedback
+ * @param bulkPredict    flag to test bulk prediction
* @param negativeWeights whether the generated data can contain negative values
- * @param numBlocks number of blocks to partition users and products into
+ * @param numUserBlocks number of user blocks to partition users into
+ * @param numProductBlocks number of product blocks to partition products into
* @param negativeFactors whether the generated user/product factors can have negative entries
*/
- def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
- bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1,
- negativeFactors: Boolean = true)
- {
+ def testALS(
+ users: Int,
+ products: Int,
+ features: Int,
+ iterations: Int,
+ samplingRate: Double,
+ matchThreshold: Double,
+ implicitPrefs: Boolean = false,
+ bulkPredict: Boolean = false,
+ negativeWeights: Boolean = false,
+ numUserBlocks: Int = -1,
+ numProductBlocks: Int = -1,
+ negativeFactors: Boolean = true) {
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
features, samplingRate, implicitPrefs, negativeWeights, negativeFactors)
- val model = (new ALS().setBlocks(numBlocks).setRank(features).setIterations(iterations)
- .setAlpha(1.0).setImplicitPrefs(implicitPrefs).setLambda(0.01).setSeed(0L)
- .setNonnegative(!negativeFactors).run(sc.parallelize(sampledRatings)))
+ val model = new ALS()
+ .setUserBlocks(numUserBlocks)
+ .setProductBlocks(numProductBlocks)
+ .setRank(features)
+ .setIterations(iterations)
+ .setAlpha(1.0)
+ .setImplicitPrefs(implicitPrefs)
+ .setLambda(0.01)
+ .setSeed(0L)
+ .setNonnegative(!negativeFactors)
+ .run(sc.parallelize(sampledRatings))
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
@@ -208,8 +229,9 @@ class ALSSuite extends FunSuite with LocalSparkContext {
val prediction = predictedRatings.get(u, p)
val correct = trueRatings.get(u, p)
if (math.abs(prediction - correct) > matchThreshold) {
- fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
- u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ fail(("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s")
+ .format(u, p, correct, prediction, trueRatings, predictedRatings, predictedU,
+ predictedP))
}
}
} else {
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 1477809943573..bb2d73741c3bf 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -15,16 +15,26 @@
* limitations under the License.
*/
-import com.typesafe.tools.mima.core.{MissingTypesProblem, MissingClassProblem, ProblemFilters}
+import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.core.MissingClassProblem
+import com.typesafe.tools.mima.core.MissingTypesProblem
import com.typesafe.tools.mima.core.ProblemFilters._
import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact}
import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings
import sbt._
object MimaBuild {
+
+ def excludeMember(fullName: String) = Seq(
+ ProblemFilters.exclude[MissingMethodProblem](fullName),
+ ProblemFilters.exclude[MissingFieldProblem](fullName),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName),
+ ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName)
+ )
+
// Exclude a single class and its corresponding object
- def excludeClass(className: String) = {
- Seq(
+ def excludeClass(className: String) = Seq(
excludePackage(className),
ProblemFilters.exclude[MissingClassProblem](className),
ProblemFilters.exclude[MissingTypesProblem](className),
@@ -32,7 +42,7 @@ object MimaBuild {
ProblemFilters.exclude[MissingClassProblem](className + "$"),
ProblemFilters.exclude[MissingTypesProblem](className + "$")
)
- }
+
// Exclude a Spark class, that is in the package org.apache.spark
def excludeSparkClass(className: String) = {
excludeClass("org.apache.spark." + className)
@@ -49,20 +59,25 @@ object MimaBuild {
val defaultExcludes = Seq()
// Read package-private excludes from file
- val excludeFilePath = (base.getAbsolutePath + "/.generated-mima-excludes")
- val excludeFile = file(excludeFilePath)
+ val classExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-class-excludes")
+ val memberExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-member-excludes")
+
val ignoredClasses: Seq[String] =
- if (!excludeFile.exists()) {
+ if (!classExcludeFilePath.exists()) {
Seq()
} else {
- IO.read(excludeFile).split("\n")
+ IO.read(classExcludeFilePath).split("\n")
}
+ val ignoredMembers: Seq[String] =
+ if (!memberExcludeFilePath.exists()) {
+ Seq()
+ } else {
+ IO.read(memberExcludeFilePath).split("\n")
+ }
-
- val externalExcludeFileClasses = ignoredClasses.flatMap(excludeClass)
-
- defaultExcludes ++ externalExcludeFileClasses ++ MimaExcludes.excludes
+ defaultExcludes ++ ignoredClasses.flatMap(excludeClass) ++
+ ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes
}
def mimaSettings(sparkHome: File) = mimaDefaultSettings ++ Seq(
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd7efceb23c96..042fdfcc47261 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,11 +52,27 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
+ "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.MemoryStore.Entry"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$"
+ + "createZero$1")
+ ) ++
+ Seq( // Ignore some private methods in ALS.
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
+ ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments.
+ "org.apache.spark.mllib.recommendation.ALS.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7")
) ++
MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
- MimaBuild.excludeSparkClass("util.SerializableHyperLogLog")
+ MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
+ MimaBuild.excludeSparkClass("storage.Values") ++
+ MimaBuild.excludeSparkClass("storage.Entry") ++
+ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry")
case v if v.startsWith("1.0") =>
Seq(
MimaBuild.excludeSparkPackage("api.java"),
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 069913dbaac56..2d60a44f04f6f 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -59,6 +59,10 @@ object SparkBuild extends Build {
lazy val core = Project("core", file("core"), settings = coreSettings)
+ /** The following project only exists to pull previous artifacts of Spark for generating
+ Mima ignores. For more information, see SPARK-2071. */
+ lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings)
+
def replDependencies = Seq[ProjectReference](core, graphx, bagel, mllib, sql) ++ maybeHiveRef
lazy val repl = Project("repl", file("repl"), settings = replSettings)
@@ -86,7 +90,16 @@ object SparkBuild extends Build {
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*)
- lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
+ lazy val assembleDepsTask = TaskKey[Unit]("assemble-deps")
+ lazy val assembleDeps = assembleDepsTask := {
+ println()
+ println("**** NOTE ****")
+ println("'sbt/sbt assemble-deps' is no longer supported.")
+ println("Instead create a normal assembly and:")
+ println(" export SPARK_PREPEND_CLASSES=1 (toggle on)")
+ println(" unset SPARK_PREPEND_CLASSES (toggle off)")
+ println()
+ }
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
@@ -336,6 +349,7 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"com.google.guava" % "guava" % "14.0.1",
"org.apache.commons" % "commons-lang3" % "3.3.2",
+ "org.apache.commons" % "commons-math3" % "3.3" % "test",
"com.google.code.findbugs" % "jsr305" % "1.3.9",
"log4j" % "log4j" % "1.2.17",
"org.slf4j" % "slf4j-api" % slf4jVersion,
@@ -369,6 +383,7 @@ object SparkBuild extends Build {
"net.sf.py4j" % "py4j" % "0.8.1"
),
libraryDependencies ++= maybeAvro,
+ assembleDeps,
previousArtifact := sparkPreviousArtifact("spark-core")
)
@@ -580,9 +595,7 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
- assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
- jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
+ jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(
@@ -598,6 +611,17 @@ object SparkBuild extends Build {
}
)
+ def oldDepsSettings() = Defaults.defaultSettings ++ Seq(
+ name := "old-deps",
+ scalaVersion := "2.10.4",
+ retrieveManaged := true,
+ retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
+ libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq",
+ "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter",
+ "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx",
+ "spark-core").map(sparkPreviousArtifact(_).get intransitive())
+ )
+
def twitterSettings() = sharedSettings ++ Seq(
name := "spark-streaming-twitter",
previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"),
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index a411a5d5914e0..e609b60a0f968 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -454,7 +454,7 @@ def _squared_distance(v1, v2):
v2 = _convert_vector(v2)
if type(v1) == ndarray and type(v2) == ndarray:
diff = v1 - v2
- return diff.dot(diff)
+ return numpy.dot(diff, diff)
elif type(v1) == ndarray:
return v2.squared_distance(v1)
else:
@@ -469,10 +469,12 @@ def _dot(vec, target):
calling numpy.dot of the two vectors, but for SciPy ones, we
have to transpose them because they're column vectors.
"""
- if type(vec) == ndarray or type(vec) == SparseVector:
+ if type(vec) == ndarray:
+ return numpy.dot(vec, target)
+ elif type(vec) == SparseVector:
return vec.dot(target)
elif type(vec) == list:
- return _convert_vector(vec).dot(target)
+ return numpy.dot(_convert_vector(vec), target)
else:
return vec.transpose().dot(target)[0]
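For dense numpy arrays the switch to numpy.dot does not change any results; it only routes the computation through the free function instead of the array's own dot method. A quick check of the two quantities involved, on hypothetical vectors:

    import numpy

    v1 = numpy.array([1.0, 2.0, 3.0])
    v2 = numpy.array([0.5, 2.0, 5.0])
    diff = v1 - v2
    print(numpy.dot(diff, diff))   # 4.25, the squared Euclidean distance used by _squared_distance
    print(numpy.dot(v1, v2))       # 19.5, the plain dot product used by _dot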
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 9c69c79236edc..ddd22850a819c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,6 +31,7 @@
import warnings
import heapq
from random import Random
+from math import sqrt, log
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
@@ -202,9 +203,9 @@ def cache(self):
def persist(self, storageLevel):
"""
- Set this RDD's storage level to persist its values across operations after the first time
- it is computed. This can only be used to assign a new storage level if the RDD does not
- have a storage level set yet.
+ Set this RDD's storage level to persist its values across operations
+ after the first time it is computed. This can only be used to assign
+ a new storage level if the RDD does not have a storage level set yet.
"""
self.is_cached = True
javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
@@ -213,7 +214,8 @@ def persist(self, storageLevel):
def unpersist(self):
"""
- Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ Mark the RDD as non-persistent, and remove all blocks for it from
+ memory and disk.
"""
self.is_cached = False
self._jrdd.unpersist()
@@ -357,48 +359,87 @@ def sample(self, withReplacement, fraction, seed=None):
# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed=None):
"""
- Return a fixed-size sampled subset of this RDD (currently requires numpy).
+ Return a fixed-size sampled subset of this RDD (currently requires
+ numpy).
- >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
- [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+ >>> rdd = sc.parallelize(range(0, 10))
+ >>> len(rdd.takeSample(True, 20, 1))
+ 20
+ >>> len(rdd.takeSample(False, 5, 2))
+ 5
+ >>> len(rdd.takeSample(False, 15, 3))
+ 10
"""
+ numStDev = 10.0
+
+ if num < 0:
+ raise ValueError("Sample size cannot be negative.")
+ elif num == 0:
+ return []
- fraction = 0.0
- total = 0
- multiplier = 3.0
initialCount = self.count()
- maxSelected = 0
+ if initialCount == 0:
+ return []
- if (num < 0):
- raise ValueError
+ rand = Random(seed)
- if (initialCount == 0):
- return list()
+ if (not withReplacement) and num >= initialCount:
+ # shuffle current RDD and return
+ samples = self.collect()
+ rand.shuffle(samples)
+ return samples
- if initialCount > sys.maxint - 1:
- maxSelected = sys.maxint - 1
- else:
- maxSelected = initialCount
-
- if num > initialCount and not withReplacement:
- total = maxSelected
- fraction = multiplier * (maxSelected + 1) / initialCount
- else:
- fraction = multiplier * (num + 1) / initialCount
- total = num
+ maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
+ if num > maxSampleSize:
+ raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
+ fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
samples = self.sample(withReplacement, fraction, seed).collect()
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
- rand = Random(seed)
- while len(samples) < total:
- samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect()
-
- sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint))
- sampler.shuffle(samples)
- return samples[0:total]
+ while len(samples) < num:
+ # TODO: add log warning for when more than one iteration was run
+ seed = rand.randint(0, sys.maxint)
+ samples = self.sample(withReplacement, fraction, seed).collect()
+
+ rand.shuffle(samples)
+
+ return samples[0:num]
+
+ @staticmethod
+ def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
+ """
+ Returns a sampling rate that guarantees a sample of
+ size >= sampleSizeLowerBound 99.99% of the time.
+
+ How the sampling rate is determined:
+ Let p = num / total, where num is the sample size and total is the
+ total number of data points in the RDD. We're trying to compute
+ q > p such that
+ - when sampling with replacement, we're drawing each data point
+ with prob_i ~ Pois(q), where we want to guarantee
+ Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to
+ total), i.e. the failure rate of not having a sufficiently large
+ sample < 0.0001. Setting q = p + 5 * sqrt(p/total) is sufficient
+ to guarantee 0.9999 success rate for num > 12, but we need a
+ slightly larger q (9 empirically determined).
+ - when sampling without replacement, we're drawing each data point
+ with prob_i ~ Binomial(total, fraction) and our choice of q
+ guarantees 1-delta, or 0.9999 success rate, where success rate is
+ defined the same as in sampling with replacement.
+ """
+ fraction = float(sampleSizeLowerBound) / total
+ if withReplacement:
+ numStDev = 5
+ if (sampleSizeLowerBound < 12):
+ numStDev = 9
+ return fraction + numStDev * sqrt(fraction / total)
+ else:
+ delta = 0.00005
+ gamma = - log(delta) / total
+ return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
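The formula documented above can be exercised on its own, without a SparkContext, to see how much head-room it adds over the naive num / total rate. A standalone sketch mirroring _computeFractionForSampleSize:

    from math import sqrt, log

    def compute_fraction(sample_size, total, with_replacement):
        # mirrors RDD._computeFractionForSampleSize above
        fraction = float(sample_size) / total
        if with_replacement:
            num_std_dev = 9 if sample_size < 12 else 5
            return fraction + num_std_dev * sqrt(fraction / total)
        else:
            delta = 0.00005
            gamma = -log(delta) / total
            return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))

    total = 10000
    for num in (10, 100, 1000):
        print(num,
              round(compute_fraction(num, total, True), 5),
              round(compute_fraction(num, total, False), 5))
    # each fraction sits a little above num / total, so sample() rarely comes up short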
def union(self, other):
"""
@@ -422,8 +463,8 @@ def union(self, other):
def intersection(self, other):
"""
- Return the intersection of this RDD and another one. The output will not
- contain any duplicate elements, even if the input RDDs did.
+ Return the intersection of this RDD and another one. The output will
+ not contain any duplicate elements, even if the input RDDs did.
Note that this method performs a shuffle internally.
@@ -665,8 +706,8 @@ def aggregate(self, zeroValue, seqOp, combOp):
modify C{t2}.
The first function (seqOp) can return a different result type, U, than
- the type of this RDD. Thus, we need one operation for merging a T into an U
- and one operation for merging two U
+ the type of this RDD. Thus, we need one operation for merging a T into
+ a U and one operation for merging two U's
>>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1))
>>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1]))
@@ -695,7 +736,7 @@ def max(self):
def min(self):
"""
- Find the maximum item in this RDD.
+ Find the minimum item in this RDD.
>>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
1.0
@@ -759,8 +800,9 @@ def stdev(self):
def sampleStdev(self):
"""
- Compute the sample standard deviation of this RDD's elements (which corrects for bias in
- estimating the standard deviation by dividing by N-1 instead of N).
+ Compute the sample standard deviation of this RDD's elements (which
+ corrects for bias in estimating the standard deviation by dividing by
+ N-1 instead of N).
>>> sc.parallelize([1, 2, 3]).sampleStdev()
1.0
@@ -769,8 +811,8 @@ def sampleStdev(self):
def sampleVariance(self):
"""
- Compute the sample variance of this RDD's elements (which corrects for bias in
- estimating the variance by dividing by N-1 instead of N).
+ Compute the sample variance of this RDD's elements (which corrects
+ for bias in estimating the variance by dividing by N-1 instead of N).
>>> sc.parallelize([1, 2, 3]).sampleVariance()
1.0
@@ -822,8 +864,8 @@ def merge(a, b):
def takeOrdered(self, num, key=None):
"""
- Get the N elements from a RDD ordered in ascending order or as specified
- by the optional key function.
+ Get the N elements from a RDD ordered in ascending order or as
+ specified by the optional key function.
>>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
[1, 2, 3, 4, 5, 6]
@@ -912,8 +954,9 @@ def first(self):
def saveAsPickleFile(self, path, batchSize=10):
"""
- Save this RDD as a SequenceFile of serialized objects. The serializer used is
- L{pyspark.serializers.PickleSerializer}, default batch size is 10.
+ Save this RDD as a SequenceFile of serialized objects. The serializer
+ used is L{pyspark.serializers.PickleSerializer}, default batch size
+ is 10.
>>> tmpFile = NamedTemporaryFile(delete=True)
>>> tmpFile.close()
@@ -1178,19 +1221,37 @@ def _mergeCombiners(iterator):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)
+
+ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
+ """
+ Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ This function can return a different result type, U, than the type of the values in this RDD,
+ V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+ The former operation is used for merging values within a partition, and the latter is used
+ for merging values between partitions. To avoid memory allocation, both of these functions are
+ allowed to modify and return their first argument instead of creating a new U.
+ """
+ def createZero():
+ return copy.deepcopy(zeroValue)
+
+ return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
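Outside Spark, the contract described in the docstring can be sketched with two hand-made partitions: seqFunc folds each value into the per-partition U, and combFunc merges the per-partition U's that share a key. A hypothetical pure-Python illustration of those semantics (no shuffle, no RDDs):

    from collections import defaultdict
    import copy

    # two hand-made "partitions" of (key, value) pairs
    partitions = [[(1, 1), (1, 2), (3, 2)], [(1, 2), (5, 3)]]
    zero_value = set()

    def seq_func(u, v):      # merge a value V into the running set U (within a partition)
        u.add(v)
        return u

    def comb_func(u1, u2):   # merge two per-partition sets U (across partitions)
        return u1 | u2

    partials = []
    for part in partitions:
        acc = defaultdict(lambda: copy.deepcopy(zero_value))   # fresh zero per key
        for k, v in part:
            acc[k] = seq_func(acc[k], v)
        partials.append(acc)

    result = {}
    for acc in partials:                  # the "shuffle" step
        for k, u in acc.items():
            result[k] = comb_func(result[k], u) if k in result else u
    print(result)                         # {1: {1, 2}, 3: {2}, 5: {3}}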
def foldByKey(self, zeroValue, func, numPartitions=None):
"""
- Merge the values for each key using an associative function "func" and a neutral "zeroValue"
- which may be added to the result an arbitrary number of times, and must not change
- the result (e.g., 0 for addition, or 1 for multiplication.).
+ Merge the values for each key using an associative function "func"
+ and a neutral "zeroValue" which may be added to the result an
+ arbitrary number of times, and must not change the result
+ (e.g., 0 for addition, or 1 for multiplication.).
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> from operator import add
>>> rdd.foldByKey(0, add).collect()
[('a', 2), ('b', 1)]
"""
- return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions)
+ def createZero():
+ return copy.deepcopy(zeroValue)
+
+ return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
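The createZero indirection matters because zeroValue may be mutable: if every key folded into the very same object, keys would leak into each other. A small local illustration of the difference, assuming a list zero and an appending fold function (not the RDD code path):

    import copy

    pairs = [("a", 1), ("b", 2), ("a", 3)]

    def append(lst, v):      # a fold function that mutates its zero value
        lst.append(v)
        return lst

    def fold_by_key(pairs, zero_value, func, fresh_zero):
        acc = {}
        for k, v in pairs:
            if k not in acc:
                acc[k] = copy.deepcopy(zero_value) if fresh_zero else zero_value
            acc[k] = func(acc[k], v)
        return acc

    print(fold_by_key(pairs, [], append, fresh_zero=True))   # {'a': [1, 3], 'b': [2]}
    print(fold_by_key(pairs, [], append, fresh_zero=False))  # {'a': [1, 2, 3], 'b': [1, 2, 3]}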
# TODO: support variant with custom partitioner
@@ -1200,8 +1261,8 @@ def groupByKey(self, numPartitions=None):
Hash-partitions the resulting RDD into numPartitions partitions.
Note: If you are grouping in order to perform an aggregation (such as a
- sum or average) over each key, using reduceByKey will provide much better
- performance.
+ sum or average) over each key, using reduceByKey will provide much
+ better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
@@ -1261,8 +1322,8 @@ def groupWith(self, other):
def cogroup(self, other, numPartitions=None):
"""
For each key k in C{self} or C{other}, return a resulting RDD that
- contains a tuple with the list of values for that key in C{self} as well
- as C{other}.
+ contains a tuple with the list of values for that key in C{self} as
+ well as C{other}.
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
@@ -1273,8 +1334,8 @@ def cogroup(self, other, numPartitions=None):
def subtractByKey(self, other, numPartitions=None):
"""
- Return each (key, value) pair in C{self} that has no pair with matching key
- in C{other}.
+ Return each (key, value) pair in C{self} that has no pair with matching
+ key in C{other}.
>>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)])
>>> y = sc.parallelize([("a", 3), ("c", None)])
@@ -1312,10 +1373,10 @@ def repartition(self, numPartitions):
"""
Return a new RDD that has exactly numPartitions partitions.
- Can increase or decrease the level of parallelism in this RDD. Internally, this uses
- a shuffle to redistribute data.
- If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
- which can avoid performing a shuffle.
+ Can increase or decrease the level of parallelism in this RDD.
+ Internally, this uses a shuffle to redistribute data.
+ If you are decreasing the number of partitions in this RDD, consider
+ using `coalesce`, which can avoid performing a shuffle.
>>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4)
>>> sorted(rdd.glom().collect())
[[1], [2, 3], [4, 5], [6, 7]]
@@ -1340,9 +1401,10 @@ def coalesce(self, numPartitions, shuffle=False):
def zip(self, other):
"""
- Zips this RDD with another one, returning key-value pairs with the first element in each RDD
- second element in each RDD, etc. Assumes that the two RDDs have the same number of
- partitions and the same number of elements in each partition (e.g. one was made through
+ Zips this RDD with another one, returning key-value pairs with the
+ first element in each RDD, second element in each RDD, etc. Assumes
+ that the two RDDs have the same number of partitions and the same
+ number of elements in each partition (e.g. one was made through
a map on the other).
>>> x = sc.parallelize(range(0,5))
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index b4e9618cc25b5..960d0a82448aa 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -117,7 +117,7 @@ def parquetFile(self, path):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile(parquetFile)
>>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> srdd.collect() == srdd2.collect()
+ >>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
jschema_rdd = self._ssql_ctx.parquetFile(path)
@@ -141,7 +141,7 @@ def table(self, tableName):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.table("table1")
- >>> srdd.collect() == srdd2.collect()
+ >>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
return SchemaRDD(self._ssql_ctx.table(tableName), self)
@@ -293,7 +293,7 @@ def saveAsParquetFile(self, path):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile(parquetFile)
>>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> srdd2.collect() == srdd.collect()
+ >>> sorted(srdd2.collect()) == sorted(srdd.collect())
True
"""
self._jschema_rdd.saveAsParquetFile(path)
@@ -307,7 +307,7 @@ def registerAsTable(self, name):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.registerAsTable("test")
>>> srdd2 = sqlCtx.sql("select * from test")
- >>> srdd.collect() == srdd2.collect()
+ >>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
self._jschema_rdd.registerAsTable(name)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 184ee810b861b..c15bb457759ed 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -188,6 +188,21 @@ def test_deleting_input_files(self):
os.unlink(tempFile.name)
self.assertRaises(Exception, lambda: filtered_data.count())
+ def testAggregateByKey(self):
+ data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
+ def seqOp(x, y):
+ x.add(y)
+ return x
+
+ def combOp(x, y):
+ x |= y
+ return x
+
+ sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
+ self.assertEqual(3, len(sets))
+ self.assertEqual(set([1]), sets[1])
+ self.assertEqual(set([2]), sets[3])
+ self.assertEqual(set([1, 3]), sets[5])
class TestIO(PySparkTestCase):
diff --git a/python/run-tests b/python/run-tests
index 3b4501178c89f..9282aa47e8375 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -44,7 +44,6 @@ function run_test() {
echo -en "\033[0m" # No color
exit -1
fi
-
}
echo "Running PySpark tests. Output is in python/unit-tests.log."
@@ -55,9 +54,13 @@ run_test "pyspark/conf.py"
if [ -n "$_RUN_SQL_TESTS" ]; then
run_test "pyspark/sql.py"
fi
+# These tests are included in the module-level docs, and so must
+# be handled on a higher level rather than within the python file.
+export PYSPARK_DOC_TEST=1
run_test "pyspark/broadcast.py"
run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
+unset PYSPARK_DOC_TEST
run_test "pyspark/tests.py"
run_test "pyspark/mllib/_common.py"
run_test "pyspark/mllib/classification.py"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 36758f3114e59..46fcfbb9e26ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -111,6 +111,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AVG = Keyword("AVG")
protected val BY = Keyword("BY")
+ protected val CACHE = Keyword("CACHE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
@@ -149,7 +150,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SEMI = Keyword("SEMI")
protected val STRING = Keyword("STRING")
protected val SUM = Keyword("SUM")
+ protected val TABLE = Keyword("TABLE")
protected val TRUE = Keyword("TRUE")
+ protected val UNCACHE = Keyword("UNCACHE")
protected val UNION = Keyword("UNION")
protected val WHERE = Keyword("WHERE")
@@ -189,7 +192,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
)
- | insert
+ | insert | cache
)
protected lazy val select: Parser[LogicalPlan] =
@@ -220,6 +223,11 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
InsertIntoTable(r, Map[String, Option[String]](), s, overwrite)
}
+ protected lazy val cache: Parser[LogicalPlan] =
+ (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ {
+ case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache)
+ }
+
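The new rule turns "CACHE TABLE <name>" into CacheCommand(<name>, true) and "UNCACHE TABLE <name>" into CacheCommand(<name>, false). A rough Python stand-in for what the combinator produces (the CacheCommand namedtuple and parse_cache helper here are hypothetical illustrations, not Spark APIs):

    import re
    from collections import namedtuple

    CacheCommand = namedtuple("CacheCommand", ["table_name", "do_cache"])

    def parse_cache(sql):
        # mirrors: (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident
        m = re.match(r"\s*(CACHE|UNCACHE)\s+TABLE\s+(\w+)\s*$", sql, re.IGNORECASE)
        if m is None:
            return None
        return CacheCommand(m.group(2), m.group(1).upper() == "CACHE")

    print(parse_cache("CACHE TABLE logs"))     # CacheCommand(table_name='logs', do_cache=True)
    print(parse_cache("UNCACHE TABLE logs"))   # CacheCommand(table_name='logs', do_cache=False)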
protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
protected lazy val projection: Parser[Expression] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3cf163f9a9a75..d177339d40ae5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -175,6 +175,8 @@ package object dsl {
def where(condition: Expression) = Filter(condition, logicalPlan)
+ def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
+
def join(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 420303408451f..c074b7bb01e57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -76,7 +76,8 @@ trait CaseConversionExpression {
type EvaluatedType = Any
def convert(v: String): String
-
+
+ override def foldable: Boolean = child.foldable
def nullable: Boolean = child.nullable
def dataType: DataType = StringType
@@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression)
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
override def convert(v: String): String = v.toUpperCase()
+
+ override def toString() = s"Upper($child)"
}
/**
@@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
override def convert(v: String): String = v.toLowerCase()
+
+ override def toString() = s"Lower($child)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ccb8245cc2e7d..25a347bec0e4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -29,12 +29,15 @@ import org.apache.spark.sql.catalyst.types._
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
+ Batch("Combine Limits", FixedPoint(100),
+ CombineLimits) ::
Batch("ConstantFolding", FixedPoint(100),
NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyFilters,
- SimplifyCasts) ::
+ SimplifyCasts,
+ SimplifyCaseConversionExpressions) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
PushPredicateThroughProject,
@@ -104,8 +107,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
- case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType)
+ case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
+ case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType)
case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType)
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
@@ -130,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] {
case Literal(candidate, _) if candidate == v => true
case _ => false
})) => Literal(true, BooleanType)
- case e: UnaryMinus => e.child match {
- case Literal(null, _) => Literal(null, e.dataType)
- case _ => e
- }
- case e: Cast => e.child match {
- case Literal(null, _) => Literal(null, e.dataType)
- case _ => e
- }
- case e: Not => e.child match {
- case Literal(null, _) => Literal(null, e.dataType)
- case _ => e
- }
// Put exceptional cases above if any
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
@@ -362,3 +353,29 @@ object SimplifyCasts extends Rule[LogicalPlan] {
case Cast(e, dataType) if e.dataType == dataType => e
}
}
+
+/**
+ * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the
+ * expressions into one single expression.
+ */
+object CombineLimits extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case ll @ Limit(le, nl @ Limit(ne, grandChild)) =>
+ Limit(If(LessThan(ne, le), ne, le), grandChild)
+ }
+}
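When both limits are literals, the If(LessThan(ne, le), ne, le) expression built above reduces to the smaller of the two, so nested limits collapse into one. A toy Python rendering of the rewrite (a sketch of the idea, not Catalyst):

    # represent a plan as nested tuples: ("limit", n, child)
    plan = ("limit", 5, ("limit", 10, ("relation", "t")))

    def combine_limits(plan):
        # Limit(le, Limit(ne, child)) => Limit(min(ne, le), child), applied to a fixed point
        if plan[0] == "limit" and plan[2][0] == "limit":
            outer, inner = plan[1], plan[2][1]
            return combine_limits(("limit", min(inner, outer), plan[2][2]))
        return plan

    print(combine_limits(plan))   # ('limit', 5, ('relation', 't'))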
+
+/**
+ * Removes inner [[catalyst.expressions.CaseConversionExpression]]s that are unnecessary because
+ * the inner conversion is overwritten by the outer one.
+ */
+object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case Upper(Upper(child)) => Upper(child)
+ case Upper(Lower(child)) => Upper(child)
+ case Lower(Upper(child)) => Lower(child)
+ case Lower(Lower(child)) => Lower(child)
+ }
+ }
+}
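The rewrite relies on the outer case conversion making the inner one irrelevant, which is easy to sanity-check on plain Python strings:

    s = "SparkSQL"
    # the inner Upper/Lower is always overwritten by the outer one
    assert s.upper().upper() == s.upper()
    assert s.lower().upper() == s.upper()
    assert s.upper().lower() == s.lower()
    assert s.lower().lower() == s.lower()
    print("all four rewrites preserve the result")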
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7eeb98aea6368..0933a31c362d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.{StringType, StructType}
+import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees
abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
@@ -96,39 +96,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
def references = Set.empty
}
-/**
- * A logical node that represents a non-query command to be executed by the system. For example,
- * commands can be used by parsers to represent DDL operations.
- */
-abstract class Command extends LeafNode {
- self: Product =>
- def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this
-}
-
-/**
- * Returned for commands supported by a given parser, but not catalyst. In general these are DDL
- * commands that are passed directly to another system.
- */
-case class NativeCommand(cmd: String) extends Command
-
-/**
- * Commands of the form "SET (key) (= value)".
- */
-case class SetCommand(key: Option[String], value: Option[String]) extends Command {
- override def output = Seq(
- AttributeReference("key", StringType, nullable = false)(),
- AttributeReference("value", StringType, nullable = false)()
- )
-}
-
-/**
- * Returned by a parser when the users only wants to see what query plan would be executed, without
- * actually performing the execution.
- */
-case class ExplainCommand(plan: LogicalPlan) extends Command {
- override def output = Seq(AttributeReference("plan", StringType, nullable = false)())
-}
-
/**
* A logical plan node with single child.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3347b622f3d8..b777cf4249196 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -135,9 +135,9 @@ case class Aggregate(
def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
}
-case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
- def references = limit.references
+ def references = limitExpr.references
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
new file mode 100644
index 0000000000000..3299e86b85941
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference}
+import org.apache.spark.sql.catalyst.types.StringType
+
+/**
+ * A logical node that represents a non-query command to be executed by the system. For example,
+ * commands can be used by parsers to represent DDL operations.
+ */
+abstract class Command extends LeafNode {
+ self: Product =>
+ def output: Seq[Attribute] = Seq.empty
+}
+
+/**
+ * Returned for commands supported by a given parser, but not catalyst. In general these are DDL
+ * commands that are passed directly to another system.
+ */
+case class NativeCommand(cmd: String) extends Command {
+ override def output =
+ Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)()))
+}
+
+/**
+ * Commands of the form "SET (key) (= value)".
+ */
+case class SetCommand(key: Option[String], value: Option[String]) extends Command {
+ override def output = Seq(
+ BoundReference(0, AttributeReference("key", StringType, nullable = false)()),
+ BoundReference(1, AttributeReference("value", StringType, nullable = false)()))
+}
+
+/**
+ * Returned by a parser when the user only wants to see what query plan would be executed, without
+ * actually performing the execution.
+ */
+case class ExplainCommand(plan: LogicalPlan) extends Command {
+ override def output =
+ Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)()))
+}
+
+/**
+ * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command.
+ */
+case class CacheCommand(tableName: String, doCache: Boolean) extends Command
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
new file mode 100644
index 0000000000000..714f01843c0f5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class CombiningLimitsSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Combine Limit", FixedPoint(2),
+ CombineLimits) ::
+ Batch("Constant Folding", FixedPoint(3),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("limits: combines two limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(10)
+ .limit(5)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(5).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("limits: combines three limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(2)
+ .limit(7)
+ .limit(5)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 20dfba847790c..6efc0e211eb21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType}
+import org.apache.spark.sql.catalyst.types._
// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("Constant folding test: expressions have null literals") {
+ val originalQuery =
+ testRelation
+ .select(
+ IsNull(Literal(null)) as 'c1,
+ IsNotNull(Literal(null)) as 'c2,
+
+ GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
+ GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
+ GetField(
+ Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
+ "a") as 'c5,
+
+ UnaryMinus(Literal(null, IntegerType)) as 'c6,
+ Cast(Literal(null), IntegerType) as 'c7,
+ Not(Literal(null, BooleanType)) as 'c8,
+
+ Add(Literal(null, IntegerType), 1) as 'c9,
+ Add(1, Literal(null, IntegerType)) as 'c10,
+
+ Equals(Literal(null, IntegerType), 1) as 'c11,
+ Equals(1, Literal(null, IntegerType)) as 'c12,
+
+ Like(Literal(null, StringType), "abc") as 'c13,
+ Like("abc", Literal(null, StringType)) as 'c14,
+
+ Upper(Literal(null, StringType)) as 'c15)
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ Literal(true) as 'c1,
+ Literal(false) as 'c2,
+
+ Literal(null, IntegerType) as 'c3,
+ Literal(null, IntegerType) as 'c4,
+ Literal(null, IntegerType) as 'c5,
+
+ Literal(null, IntegerType) as 'c6,
+ Literal(null, IntegerType) as 'c7,
+ Literal(null, BooleanType) as 'c8,
+
+ Literal(null, IntegerType) as 'c9,
+ Literal(null, IntegerType) as 'c10,
+
+ Literal(null, BooleanType) as 'c11,
+ Literal(null, BooleanType) as 'c12,
+
+ Literal(null, BooleanType) as 'c13,
+ Literal(null, BooleanType) as 'c14,
+
+ Literal(null, StringType) as 'c15)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 02cc665f8a8c7..1f67c80e54906 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -20,14 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.junit.Test
class FilterPushdownSuite extends OptimizerTest {
@@ -164,7 +161,7 @@ class FilterPushdownSuite extends OptimizerTest {
comparePlans(optimized, correctAnswer)
}
-
+
test("joins: push down left outer join #1") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
new file mode 100644
index 0000000000000..df1409fe7baee
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+/* Implicit conversions */
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+
+class SimplifyCaseConversionExpressionsSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Simplify CaseConversionExpressions", Once,
+ SimplifyCaseConversionExpressions) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.string)
+
+ test("simplify UPPER(UPPER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Upper(Upper('a)) as 'u)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select(Upper('a) as 'u)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("simplify UPPER(LOWER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Upper(Lower('a)) as 'u)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select(Upper('a) as 'u)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("simplify LOWER(UPPER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Lower(Upper('a)) as 'l)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(Lower('a) as 'l)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("simplify LOWER(LOWER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Lower(Lower('a)) as 'l)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(Lower('a) as 'l)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 021e0e8245a0d..378ff54531118 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,10 +31,10 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkStrategies
@@ -147,14 +147,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- def sql(sqlText: String): SchemaRDD = {
- val result = new SchemaRDD(this, parseSql(sqlText))
- // We force query optimization to happen right away instead of letting it happen lazily like
- // when using the query DSL. This is so DDL commands behave as expected. This is only
- // generates the RDD lineage for DML queries, but do not perform any execution.
- result.queryExecution.toRdd
- result
- }
+ def sql(sqlText: String): SchemaRDD = new SchemaRDD(this, parseSql(sqlText))
/** Returns the specified table as a SchemaRDD */
def table(tableName: String): SchemaRDD =
@@ -166,10 +159,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
val useCompression =
sparkContext.conf.getBoolean("spark.sql.inMemoryColumnarStorage.compressed", false)
val asInMemoryRelation =
- InMemoryColumnarTableScan(
- currentTable.output, executePlan(currentTable).executedPlan, useCompression)
+ InMemoryRelation(useCompression, executePlan(currentTable).executedPlan)
- catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation))
+ catalog.registerTable(None, tableName, asInMemoryRelation)
}
/** Removes the specified table from the in-memory cache. */
@@ -177,17 +169,26 @@ class SQLContext(@transient val sparkContext: SparkContext)
EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match {
// This is kind of a hack to make sure that if this was just an RDD registered as a table,
// we reregister the RDD as a table.
- case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd, _)) =>
+ case inMem @ InMemoryRelation(_, _, e: ExistingRdd) =>
inMem.cachedColumnBuffers.unpersist()
catalog.unregisterTable(None, tableName)
catalog.registerTable(None, tableName, SparkLogicalPlan(e))
- case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) =>
+ case inMem: InMemoryRelation =>
inMem.cachedColumnBuffers.unpersist()
catalog.unregisterTable(None, tableName)
case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
}
}
+ /** Returns true if the table is currently cached in-memory. */
+ def isCached(tableName: String): Boolean = {
+ val relation = catalog.lookupRelation(None, tableName)
+ EliminateAnalysisOperators(relation) match {
+ case _: InMemoryRelation => true
+ case _ => false
+ }
+ }
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext = self.sparkContext
@@ -199,6 +200,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
PartialAggregation ::
LeftSemiJoin ::
HashJoin ::
+ InMemoryScans ::
ParquetOperations ::
BasicOperators ::
CartesianProduct ::
@@ -250,8 +252,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] val planner = new SparkPlanner
@transient
- protected[sql] lazy val emptyResult =
- sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+ protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1)
/**
* Prepares a planned SparkPlan for execution by binding references to specific ordinals, and
@@ -271,22 +272,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected abstract class QueryExecution {
def logical: LogicalPlan
- def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match {
- case SetCommand(key, value) =>
- // Only this case needs to be executed eagerly. The other cases will
- // be taken care of when the actual results are being extracted.
- // In the case of HiveContext, sqlConf is overridden to also pass the
- // pair into its HiveConf.
- if (key.isDefined && value.isDefined) {
- set(key.get, value.get)
- }
- // It doesn't matter what we return here, since this is only used
- // to force the evaluation to happen eagerly. To query the results,
- // one must use SchemaRDD operations to extract them.
- emptyResult
- case _ => executedPlan.execute()
- }
-
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
@@ -294,12 +279,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
- lazy val toRdd: RDD[Row] = {
- logical match {
- case s: SetCommand => eagerlyProcess(s)
- case _ => executedPlan.execute()
- }
- }
+ lazy val toRdd: RDD[Row] = executedPlan.execute()
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -321,7 +301,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* TODO: We only support primitive types, add support for nested types.
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
- val schema = rdd.first.map { case (fieldName, obj) =>
+ val schema = rdd.first().map { case (fieldName, obj) =>
val dataType = obj.getClass match {
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
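With the eager-execution logic removed from sql(), query construction is lazy again, and the cache helpers shown above work purely through the catalog. A hedged usage sketch, assuming an existing SparkContext named sc and a table already registered as "people" (both hypothetical names):

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)

    val query = sqlContext.sql("SELECT * FROM people")   // no job runs yet; the text is only parsed

    sqlContext.cacheTable("people")
    assert(sqlContext.isCached("people"))                // new predicate introduced in this patch
    sqlContext.uncacheTable("people")
    assert(!sqlContext.isCached("people"))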
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8855c4e876917..821ac850ac3f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -97,7 +97,7 @@ import java.util.{Map => JMap}
@AlphaComponent
class SchemaRDD(
@transient val sqlContext: SQLContext,
- @transient protected[spark] val logicalPlan: LogicalPlan)
+ @transient val baseLogicalPlan: LogicalPlan)
extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike {
def baseSchemaRDD = this
@@ -178,14 +178,18 @@ class SchemaRDD(
def orderBy(sortExprs: SortOrder*): SchemaRDD =
new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
+ @deprecated("use limit with integer argument", "1.1.0")
+ def limit(limitExpr: Expression): SchemaRDD =
+ new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+
/**
- * Limits the results by the given expressions.
+ * Limits the results by the given integer.
* {{{
* schemaRDD.limit(10)
* }}}
*/
- def limit(limitExpr: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+ def limit(limitNum: Int): SchemaRDD =
+ new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
/**
* Performs a grouping followed by an aggregation.
@@ -374,6 +378,8 @@ class SchemaRDD(
override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
+ override def take(num: Int): Array[Row] = limit(num).collect()
+
// =======================================================================
// Base RDD functions that do NOT change schema
// =======================================================================
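The net effect of the SchemaRDD changes is that limit now takes a plain Int (the Expression overload remains but is deprecated) and take is answered through the same Limit operator instead of the generic RDD implementation. A hedged usage sketch, with people standing in for any SchemaRDD:

    val firstTen = people.limit(10).collect()   // new Int overload, planned as Limit(Literal(10), ...)
    val sameTen  = people.take(10)              // now delegates to limit(10).collect()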
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index 3a895e15a4508..656be965a8fd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -20,13 +20,14 @@ package org.apache.spark.sql
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.SparkLogicalPlan
/**
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
*/
private[sql] trait SchemaRDDLike {
@transient val sqlContext: SQLContext
- @transient protected[spark] val logicalPlan: LogicalPlan
+ @transient val baseLogicalPlan: LogicalPlan
private[sql] def baseSchemaRDD: SchemaRDD
@@ -48,7 +49,17 @@ private[sql] trait SchemaRDDLike {
*/
@transient
@DeveloperApi
- lazy val queryExecution = sqlContext.executePlan(logicalPlan)
+ lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
+
+ @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match {
+ // For various commands (like DDL) and queries with side effects, we force query optimization to
+ // happen right away to let these side effects take place eagerly.
+ case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile =>
+ queryExecution.toRdd
+ SparkLogicalPlan(queryExecution.executedPlan)
+ case _ =>
+ baseLogicalPlan
+ }
override def toString =
s"""${super.toString}
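Because SchemaRDDLike now forces queryExecution.toRdd for commands and table writes, their side effects happen as soon as the SchemaRDD is created rather than when an action runs. A hedged sketch of that behaviour, using a hypothetical key name:

    val result = sqlContext.sql("SET spark.sql.hypothetical.key=1")
    // The SET has already taken effect here, before any action is called:
    assert(sqlContext.getOption("spark.sql.hypothetical.key") == Some("1"))
    result.collect()   // simply replays the recorded sideEffectResult as rows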
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 22f57b758dd02..aff6ffe9f3478 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel
*/
class JavaSchemaRDD(
@transient val sqlContext: SQLContext,
- @transient protected[spark] val logicalPlan: LogicalPlan)
+ @transient val baseLogicalPlan: LogicalPlan)
extends JavaRDDLike[Row, JavaRDD[Row]]
with SchemaRDDLike {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index fdf28e1bb1261..e1e4f24c6c66c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -17,18 +17,29 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row
import org.apache.spark.SparkConf
-private[sql] case class InMemoryColumnarTableScan(
- attributes: Seq[Attribute],
- child: SparkPlan,
- useCompression: Boolean)
- extends LeafNode {
+object InMemoryRelation {
+ def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation =
+ new InMemoryRelation(child.output, useCompression, child)
+}
- override def output: Seq[Attribute] = attributes
+private[sql] case class InMemoryRelation(
+ output: Seq[Attribute],
+ useCompression: Boolean,
+ child: SparkPlan)
+ extends LogicalPlan with MultiInstanceRelation {
+
+ override def children = Seq.empty
+ override def references = Set.empty
+
+ override def newInstance() =
+ new InMemoryRelation(output.map(_.newInstance), useCompression, child).asInstanceOf[this.type]
lazy val cachedColumnBuffers = {
val output = child.output
@@ -55,14 +66,26 @@ private[sql] case class InMemoryColumnarTableScan(
cached.count()
cached
}
+}
+
+private[sql] case class InMemoryColumnarTableScan(
+ attributes: Seq[Attribute],
+ relation: InMemoryRelation)
+ extends LeafNode {
+
+ override def output: Seq[Attribute] = attributes
override def execute() = {
- cachedColumnBuffers.mapPartitions { iterator =>
+ relation.cachedColumnBuffers.mapPartitions { iterator =>
val columnBuffers = iterator.next()
assert(!iterator.hasNext)
new Iterator[Row] {
- val columnAccessors = columnBuffers.map(ColumnAccessor(_))
+ // Find the ordinals of the requested columns. If none are requested, use the first.
+ val requestedColumns =
+ if (attributes.isEmpty) Seq(0) else attributes.map(relation.output.indexOf(_))
+
+ val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_))
val nextRow = new GenericMutableRow(columnAccessors.length)
override def next() = {
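The rewritten scan no longer materialises every column buffer; it looks up only the ordinals of the requested attributes in the relation's output, falling back to the first column when nothing is requested (e.g. for COUNT(*)). A hedged, collection-level illustration of that lookup, with hypothetical names:

    val relationOutput = Seq("key", "value")            // stands in for relation.output
    val requested      = Seq("value")                   // attributes the query actually needs

    val requestedOrdinals =
      if (requested.isEmpty) Seq(0)                     // read one column anyway, e.g. for COUNT(*)
      else requested.map(relationOutput.indexOf(_))     // Seq(1): only the "value" buffers are touched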
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 4613df103943d..07967fe75e882 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -77,8 +77,6 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
- case scan @ InMemoryColumnarTableScan(output, _, _) =>
- scan.copy(attributes = output.map(_.newInstance))
case _ => sys.error("Multiple instance of the same relation detected.")
}).asInstanceOf[this.type]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0455748d40eec..2233216a6ec52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,13 +17,14 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{SQLConf, SQLContext, execution}
+import org.apache.spark.sql.{SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.parquet._
+import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
@@ -156,7 +157,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
- case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => {
+ case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
(filters: Seq[Expression]) => {
@@ -185,12 +186,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
filters,
prunePushedDownFilters,
ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
- }
case _ => Nil
}
}
+ object InMemoryScans extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
+ pruneFilterProject(
+ projectList,
+ filters,
+ identity[Seq[Expression]], // No filters are pushed down.
+ InMemoryColumnarTableScan(_, mem)) :: Nil
+ case _ => Nil
+ }
+ }
+
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def numPartitions = self.numPartitions
@@ -237,12 +249,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case class CommandStrategy(context: SQLContext) extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.SetCommand(key, value) =>
- Seq(execution.SetCommandPhysical(key, value, plan.output)(context))
+ Seq(execution.SetCommand(key, value, plan.output)(context))
case logical.ExplainCommand(child) =>
- val qe = context.executePlan(child)
- Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context))
+ val executedPlan = context.executePlan(child).executedPlan
+ Seq(execution.ExplainCommand(executedPlan, plan.output)(context))
+ case logical.CacheCommand(tableName, cache) =>
+ Seq(execution.CacheCommand(tableName, cache)(context))
case _ => Nil
}
}
-
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 9364506691f38..0377290af5926 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -22,46 +22,94 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute}
+trait Command {
+ /**
+ * A concrete command should override this lazy field to wrap up any side effects caused by the
+ * command or any other computation that should be evaluated exactly once. The value of this field
+ * can be used as the contents of the corresponding RDD generated from the physical plan of this
+ * command.
+ *
+ * The `execute()` method of all the physical command classes should reference `sideEffectResult`
+ * so that the command can be executed eagerly right after the command query is created.
+ */
+ protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any]
+}
+
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute])
- (@transient context: SQLContext) extends LeafNode {
- def execute(): RDD[Row] = (key, value) match {
- // Set value for key k; the action itself would
- // have been performed in QueryExecution eagerly.
- case (Some(k), Some(v)) => context.emptyResult
+case class SetCommand(
+ key: Option[String], value: Option[String], output: Seq[Attribute])(
+ @transient context: SQLContext)
+ extends LeafNode with Command {
+
+ override protected[sql] lazy val sideEffectResult: Seq[(String, String)] = (key, value) match {
+ // Set value for key k.
+ case (Some(k), Some(v)) =>
+ context.set(k, v)
+ Array(k -> v)
+
// Query the value bound to key k.
- case (Some(k), None) =>
- val resultString = context.getOption(k) match {
- case Some(v) => s"$k=$v"
- case None => s"$k is undefined"
- }
- context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1)
+ case (Some(k), _) =>
+ Array(k -> context.getOption(k).getOrElse(""))
+
// Query all key-value pairs that are set in the SQLConf of the context.
case (None, None) =>
- val pairs = context.getAll
- val rows = pairs.map { case (k, v) =>
- new GenericRow(Array[Any](s"$k=$v"))
- }.toSeq
- // Assume config parameters can fit into one split (machine) ;)
- context.sparkContext.parallelize(rows, 1)
- // The only other case is invalid semantics and is impossible.
- case _ => context.emptyResult
+ context.getAll
+
+ case _ =>
+ throw new IllegalArgumentException()
}
+
+ def execute(): RDD[Row] = {
+ val rows = sideEffectResult.map { case (k, v) => new GenericRow(Array[Any](k, v)) }
+ context.sparkContext.parallelize(rows, 1)
+ }
+
+ override def otherCopyArgs = context :: Nil
}
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute])
- (@transient context: SQLContext) extends UnaryNode {
+case class ExplainCommand(
+ child: SparkPlan, output: Seq[Attribute])(
+ @transient context: SQLContext)
+ extends UnaryNode with Command {
+
+ // Actually "EXPLAIN" command doesn't cause any side effect.
+ override protected[sql] lazy val sideEffectResult: Seq[String] = this.toString.split("\n")
+
def execute(): RDD[Row] = {
- val planString = new GenericRow(Array[Any](child.toString))
- context.sparkContext.parallelize(Seq(planString))
+ val explanation = sideEffectResult.mkString("\n")
+ context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](explanation))), 1)
}
override def otherCopyArgs = context :: Nil
}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext)
+ extends LeafNode with Command {
+
+ override protected[sql] lazy val sideEffectResult = {
+ if (doCache) {
+ context.cacheTable(tableName)
+ } else {
+ context.uncacheTable(tableName)
+ }
+ Seq.empty[Any]
+ }
+
+ override def execute(): RDD[Row] = {
+ sideEffectResult
+ context.emptyResult
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
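The Command trait above turns every physical command into the same two-step pattern: put the side effect (and its result) in a lazy sideEffectResult, and have execute() reference it so the work runs at most once. A minimal hedged sketch of a hypothetical command following that pattern, reusing the imports already present in commands.scala; it is not part of this patch:

    case class LogMessageCommand(message: String, output: Seq[Attribute])(
        @transient context: SQLContext)
      extends LeafNode with Command {

      // Evaluated at most once, when the command's lineage is first forced.
      override protected[sql] lazy val sideEffectResult: Seq[String] = {
        println(message)                               // the side effect itself
        Seq(message)
      }

      def execute(): RDD[Row] = {
        val rows = sideEffectResult.map(s => new GenericRow(Array[Any](s)))
        context.sparkContext.parallelize(rows, 1)
      }

      override def otherCopyArgs = context :: Nil
    }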
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 88ff3d49a79b3..8d7a5ba59f96a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -169,7 +169,7 @@ case class LeftSemiJoinHash(
def execute() = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashTable = new java.util.HashSet[Row]()
+ val hashSet = new java.util.HashSet[Row]()
var currentRow: Row = null
// Create a Hash set of buildKeys
@@ -177,43 +177,17 @@ case class LeftSemiJoinHash(
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
- val keyExists = hashTable.contains(rowKey)
+ val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
- hashTable.add(rowKey)
+ hashSet.add(rowKey)
}
}
}
- new Iterator[Row] {
- private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatched: Boolean = false
-
- private[this] val joinKeys = streamSideKeyGenerator()
-
- override final def hasNext: Boolean =
- streamIter.hasNext && fetchNext()
-
- override final def next() = {
- currentStreamedRow
- }
-
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatched = false
- while (!currentHashMatched && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatched = hashTable.contains(joinKeys.currentValue)
- }
- }
- currentHashMatched
- }
- }
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.filter(current => {
+ !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
+ })
}
}
}
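The hand-written iterator in LeftSemiJoinHash is replaced by a plain filter: a streamed row survives iff its join key is non-null and appears in the hash set built from the build side. A hedged, collection-level analogue of that predicate (not the operator itself):

    val buildKeys = Set(1, 2, 3)                                       // hypothetical build-side keys
    val streamed  = Seq(Some(1) -> "a", None -> "b", Some(5) -> "c")   // (key, row) pairs

    val semiJoined = streamed.filter { case (key, _) =>
      key.isDefined && buildKeys.contains(key.get)                     // keeps only (Some(1), "a")
    }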
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 0331f90272a99..c794da4da4069 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
-import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.test.TestSQLContext
class CachedTableSuite extends QueryTest {
@@ -34,7 +33,7 @@ class CachedTableSuite extends QueryTest {
)
TestSQLContext.table("testData").queryExecution.analyzed match {
- case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
+ case _ : InMemoryRelation => // Found evidence of caching
case noCache => fail(s"No cache node found in plan $noCache")
}
@@ -46,7 +45,7 @@ class CachedTableSuite extends QueryTest {
)
TestSQLContext.table("testData").queryExecution.analyzed match {
- case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
+ case cachePlan: InMemoryRelation =>
fail(s"Table still cached after uncache: $cachePlan")
case noCache => // Table uncached successfully
}
@@ -61,13 +60,33 @@ class CachedTableSuite extends QueryTest {
test("SELECT Star Cached Table") {
TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar")
TestSQLContext.cacheTable("selectStar")
- TestSQLContext.sql("SELECT * FROM selectStar")
+ TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect()
TestSQLContext.uncacheTable("selectStar")
}
test("Self-join cached") {
+ val unCachedAnswer =
+ TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
TestSQLContext.cacheTable("testData")
- TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key")
+ checkAnswer(
+ TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
+ unCachedAnswer.toSeq)
TestSQLContext.uncacheTable("testData")
}
+
+ test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
+ TestSQLContext.sql("CACHE TABLE testData")
+ TestSQLContext.table("testData").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => // Found evidence of caching
+ case _ => fail(s"Table 'testData' should be cached")
+ }
+ assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached")
+
+ TestSQLContext.sql("UNCACHE TABLE testData")
+ TestSQLContext.table("testData").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached")
+ case _ => // Found evidence of uncaching
+ }
+ assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached")
+ }
}
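The new test also documents the SQL-level entry points this patch adds. A hedged usage sketch against a hypothetical context with "testData" already registered as a table:

    sqlContext.sql("CACHE TABLE testData")     // planned through the new CacheCommand("testData", doCache = true)
    sqlContext.sql("UNCACHE TABLE testData")   // planned through CacheCommand("testData", doCache = false)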
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 5eb73a4eff980..08293f7f0ca30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -28,6 +28,7 @@ class SQLConfSuite extends QueryTest {
val testVal = "test.val.0"
test("programmatic ways of basic setting and getting") {
+ clear()
assert(getOption(testKey).isEmpty)
assert(getAll.toSet === Set())
@@ -48,6 +49,7 @@ class SQLConfSuite extends QueryTest {
}
test("parse SQL set commands") {
+ clear()
sql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index de02bbc7e4700..e9360b0fc7910 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -141,7 +141,7 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"),
Seq((2147483645.0,1),(2.0,2)))
}
-
+
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
@@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest {
(3, "C"),
(4, "D")))
}
-
+
test("system function upper()") {
checkAnswer(
sql("SELECT n,UPPER(l) FROM lowerCaseData"),
@@ -349,7 +349,7 @@ class SQLQuerySuite extends QueryTest {
(2, "ABC"),
(3, null)))
}
-
+
test("system function lower()") {
checkAnswer(
sql("SELECT N,LOWER(L) FROM upperCaseData"),
@@ -382,26 +382,27 @@ class SQLQuerySuite extends QueryTest {
sql(s"SET $testKey=$testVal")
checkAnswer(
sql("SET"),
- Seq(Seq(s"$testKey=$testVal"))
+ Seq(Seq(testKey, testVal))
)
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
checkAnswer(
sql("set"),
Seq(
- Seq(s"$testKey=$testVal"),
- Seq(s"${testKey + testKey}=${testVal + testVal}"))
+ Seq(testKey, testVal),
+ Seq(testKey + testKey, testVal + testVal))
)
// "set key"
checkAnswer(
sql(s"SET $testKey"),
- Seq(Seq(s"$testKey=$testVal"))
+ Seq(Seq(testKey, testVal))
)
checkAnswer(
sql(s"SET $nonexistentKey"),
- Seq(Seq(s"$nonexistentKey is undefined"))
+ Seq(Seq(nonexistentKey, ""))
)
+ clear()
}
}
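The updated expectations reflect the new result shape of SET: each setting comes back as a two-column row (key, value) instead of a single "key=value" string, and querying an unset key yields an empty value. A hedged before/after sketch with an illustrative key:

    sql("SET spark.sql.hypothetical=true")
    sql("SET spark.sql.hypothetical").collect()
    // before this patch: Row("spark.sql.hypothetical=true")
    // after this patch:  Row("spark.sql.hypothetical", "true")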
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 31c5dfba92954..86727b93f3659 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("simple columnar query") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
- val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true))
+ val scan = InMemoryRelation(useCompression = true, plan)
checkAnswer(scan, testData.collect().toSeq)
}
test("projection") {
val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
- val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true))
+ val scan = InMemoryRelation(useCompression = true, plan)
checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
@@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
- val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan, true))
+ val scan = InMemoryRelation(useCompression = true, plan)
checkAnswer(scan, testData.collect().toSeq)
checkAnswer(scan, testData.collect().toSeq)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 64978215542ec..96e0ec5136331 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -15,8 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql
-package hive
+package org.apache.spark.sql.hive
import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.util.{ArrayList => JArrayList}
@@ -32,12 +31,13 @@ import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog}
-import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.QueryExecutionException
+import org.apache.spark.sql.execution.{Command => PhysicalCommand}
/**
* Starts up an instance of hive where metadata is stored locally. An in-process metadata data is
@@ -71,14 +71,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/**
* Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD.
*/
- def hiveql(hqlQuery: String): SchemaRDD = {
- val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery))
- // We force query optimization to happen right away instead of letting it happen lazily like
- // when using the query DSL. This is so DDL commands behave as expected. This is only
- // generates the RDD lineage for DML queries, but does not perform any execution.
- result.queryExecution.toRdd
- result
- }
+ def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery))
/** An alias for `hiveql`. */
def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery)
@@ -164,7 +157,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/**
* Runs the specified SQL query using Hive.
*/
- protected def runSqlHive(sql: String): Seq[String] = {
+ protected[sql] def runSqlHive(sql: String): Seq[String] = {
val maxResults = 100000
val results = runHive(sql, 100000)
// It is very confusing when you only get back some of the results...
@@ -228,8 +221,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override val strategies: Seq[Strategy] = Seq(
CommandStrategy(self),
+ HiveCommandStrategy(self),
TakeOrdered,
ParquetOperations,
+ InMemoryScans,
HiveTableScans,
DataSinks,
Scripts,
@@ -251,25 +246,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override lazy val optimizedPlan =
optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
- override lazy val toRdd: RDD[Row] = {
- def processCmd(cmd: String): RDD[Row] = {
- val output = runSqlHive(cmd)
- if (output.size == 0) {
- emptyResult
- } else {
- val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]]))
- sparkContext.parallelize(asRows, 1)
- }
- }
-
- logical match {
- case s: SetCommand => eagerlyProcess(s)
- case _ => analyzed match {
- case NativeCommand(cmd) => processCmd(cmd)
- case _ => executedPlan.execute().map(_.copy())
- }
- }
- }
+ override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
@@ -297,7 +274,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
struct.zip(fields).map {
case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
}.mkString("{", ",", "}")
- case (seq: Seq[_], ArrayType(typ))=>
+ case (seq: Seq[_], ArrayType(typ)) =>
seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
case (map: Map[_,_], MapType(kType, vType)) =>
map.map {
@@ -313,10 +290,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* Returns the result as a hive compatible sequence of strings. For native commands, the
* execution is simply passed back to Hive.
*/
- def stringResult(): Seq[String] = analyzed match {
- case NativeCommand(cmd) => runSqlHive(cmd)
- case ExplainCommand(plan) => executePlan(plan).toString.split("\n")
- case query =>
+ def stringResult(): Seq[String] = executedPlan match {
+ case command: PhysicalCommand =>
+ command.sideEffectResult.map(_.toString)
+
+ case other =>
val result: Seq[Seq[Any]] = toRdd.collect().toSeq
// We need the types so we can output struct field names
val types = analyzed.output.map(_.dataType)
@@ -327,8 +305,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override def simpleString: String =
logical match {
- case _: NativeCommand => ""
- case _: SetCommand => ""
+ case _: NativeCommand => ""
+ case _: SetCommand => ""
case _ => executedPlan.toString
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index a91b520765349..68284344afd55 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.hive.execution.{HiveTableScan, InsertIntoHiveTable}
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -130,8 +130,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
castChildOutput(p, table, child)
- case p @ logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan(
- _, HiveTableScan(_, table, _), _)), _, child, _) =>
+ case p @ logical.InsertIntoTable(
+ InMemoryRelation(_, _,
+ HiveTableScan(_, table, _)), _, child, _) =>
castChildOutput(p, table, child)
}
@@ -236,6 +237,7 @@ object HiveMetastoreTypes extends RegexParsers {
case BinaryType => "binary"
case BooleanType => "boolean"
case DecimalType => "decimal"
+ case TimestampType => "timestamp"
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4e74d9bc909fa..b745d8ffd8f17 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -218,15 +218,19 @@ private[hive] object HiveQl {
case Array(key, value) => // "set key=value"
SetCommand(Some(key), Some(value))
}
- } else if (sql.toLowerCase.startsWith("add jar")) {
+ } else if (sql.trim.toLowerCase.startsWith("cache table")) {
+ CacheCommand(sql.drop(12).trim, true)
+ } else if (sql.trim.toLowerCase.startsWith("uncache table")) {
+ CacheCommand(sql.drop(14).trim, false)
+ } else if (sql.trim.toLowerCase.startsWith("add jar")) {
AddJar(sql.drop(8))
- } else if (sql.toLowerCase.startsWith("add file")) {
+ } else if (sql.trim.toLowerCase.startsWith("add file")) {
AddFile(sql.drop(9))
- } else if (sql.startsWith("dfs")) {
+ } else if (sql.trim.startsWith("dfs")) {
DfsCommand(sql)
- } else if (sql.startsWith("source")) {
+ } else if (sql.trim.startsWith("source")) {
SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
- } else if (sql.startsWith("!")) {
+ } else if (sql.trim.startsWith("!")) {
ShellCommand(sql.drop(1))
} else {
val tree = getAst(sql)
@@ -839,11 +843,11 @@ private[hive] object HiveQl {
case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg))
-
+
/* System functions about string operations */
case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg))
-
+
/* Casts */
case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), StringType)
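The new CACHE/UNCACHE branches in HiveQl strip a fixed-length prefix before taking the table name. A hedged worked example of the character counts (REPL-style, hypothetical input):

    "CACHE TABLE testData".drop(12).trim     // "CACHE TABLE " is 12 characters -> "testData"
    "UNCACHE TABLE testData".drop(14).trim   // "UNCACHE TABLE " is 14 characters -> "testData"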
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 8b51957162e04..0ac0ee9071f36 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.execution._
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.columnar.InMemoryRelation
private[hive] trait HiveStrategies {
// Possibly being too clever with types here... or not clever enough.
@@ -44,8 +44,9 @@ private[hive] trait HiveStrategies {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
- case logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan(
- _, HiveTableScan(_, table, _), _)), partition, child, overwrite) =>
+ case logical.InsertIntoTable(
+ InMemoryRelation(_, _,
+ HiveTableScan(_, table, _)), partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
case _ => Nil
}
@@ -75,4 +76,12 @@ private[hive] trait HiveStrategies {
Nil
}
}
+
+ case class HiveCommandStrategy(context: HiveContext) extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.NativeCommand(sql) =>
+ NativeCommand(sql, plan.output)(context) :: Nil
+ case _ => Nil
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 041e813598d1b..9386008d02d51 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand}
+import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.hive._
@@ -103,7 +103,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) {
new File("src" + File.separator + "test" + File.separator + "resources" + File.separator)
} else {
- new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" +
+ new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" +
File.separator + "resources")
}
@@ -130,6 +130,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
override lazy val analyzed = {
val describedTables = logical match {
case NativeCommand(describedTable(tbl)) => tbl :: Nil
+ case CacheCommand(tbl, _) => tbl :: Nil
case _ => Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala
index 29b4b9b006e45..a839231449161 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala
@@ -32,14 +32,15 @@ import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred._
+import org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive._
-import org.apache.spark.{TaskContext, SparkException}
import org.apache.spark.util.MutablePair
+import org.apache.spark.{TaskContext, SparkException}
/* Implicits */
import scala.collection.JavaConversions._
@@ -57,7 +58,7 @@ case class HiveTableScan(
attributes: Seq[Attribute],
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
- @transient val sc: HiveContext)
+ @transient val context: HiveContext)
extends LeafNode
with HiveInspectors {
@@ -75,7 +76,7 @@ case class HiveTableScan(
}
@transient
- val hadoopReader = new HadoopTableReader(relation.tableDesc, sc)
+ val hadoopReader = new HadoopTableReader(relation.tableDesc, context)
/**
* The hive object inspector for this table, which can be used to extract values from the
@@ -156,7 +157,7 @@ case class HiveTableScan(
hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames)
}
- addColumnMetadataToConf(sc.hiveconf)
+ addColumnMetadataToConf(context.hiveconf)
@transient
def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
@@ -428,3 +429,26 @@ case class InsertIntoHiveTable(
sc.sparkContext.makeRDD(Nil, 1)
}
}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class NativeCommand(
+ sql: String, output: Seq[Attribute])(
+ @transient context: HiveContext)
+ extends LeafNode with Command {
+
+ override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql)
+
+ override def execute(): RDD[spark.sql.Row] = {
+ if (sideEffectResult.size == 0) {
+ context.emptyResult
+ } else {
+ val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r)))
+ context.sparkContext.parallelize(rows, 1)
+ }
+ }
+
+ override def otherCopyArgs = context :: Nil
+}
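NativeCommand gives Hive-native statements the same eager Command treatment as SET and CACHE: the Hive output is captured once in sideEffectResult and replayed as single-column rows. A hedged usage sketch, assuming a hypothetical HiveContext named hiveContext:

    val tables = hiveContext.hql("SHOW TABLES")   // planned as execution.NativeCommand; runs in Hive eagerly
    tables.collect()                              // one Row per line of Hive output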
diff --git a/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c b/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c
new file mode 100644
index 0000000000000..2ed47ab83dd02
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c
@@ -0,0 +1,11 @@
+0 val_0
+0 val_0
+0 val_0
+2 val_2
+4 val_4
+5 val_5
+5 val_5
+5 val_5
+8 val_8
+9 val_9
+10 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 b/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330
new file mode 100644
index 0000000000000..a24bd8c6379e3
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330
@@ -0,0 +1,8 @@
+0 val_0
+0 val_0
+0 val_0
+4 val_2
+8 val_4
+10 val_5
+10 val_5
+10 val_5
diff --git a/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df b/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f b/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 b/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589
new file mode 100644
index 0000000000000..03c61a908b071
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589
@@ -0,0 +1,11 @@
+val_0
+val_0
+val_0
+val_10
+val_2
+val_4
+val_5
+val_5
+val_5
+val_8
+val_9
diff --git a/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d b/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019
new file mode 100644
index 0000000000000..2dcdfd1217ced
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019
@@ -0,0 +1,3 @@
+0 val_0
+0 val_0
+0 val_0
diff --git a/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 b/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f
new file mode 100644
index 0000000000000..a3670515e8cc2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f
@@ -0,0 +1,3 @@
+val_10
+val_8
+val_9
diff --git a/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc b/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 b/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 b/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 b/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda
new file mode 100644
index 0000000000000..72bc6a6a88f6e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda
@@ -0,0 +1,5 @@
+4 val_2
+8 val_4
+10 val_5
+10 val_5
+10 val_5
diff --git a/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 b/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b
new file mode 100644
index 0000000000000..d89ea1757c712
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b
@@ -0,0 +1,19 @@
+0
+0
+0
+0
+0
+0
+2
+4
+4
+5
+5
+5
+8
+8
+9
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 b/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d
new file mode 100644
index 0000000000000..dbbdae75a52a4
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d
@@ -0,0 +1,4 @@
+0 val_0
+0 val_0
+0 val_0
+8 val_8
diff --git a/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 b/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5
new file mode 100644
index 0000000000000..07c61afb5124b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5
@@ -0,0 +1,14 @@
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+4 val_4 4 val_2
+8 val_8 8 val_4
+10 val_10 10 val_5
+10 val_10 10 val_5
+10 val_10 10 val_5
diff --git a/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389
new file mode 100644
index 0000000000000..bf51e8f5d9eb5
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389
@@ -0,0 +1,11 @@
+0 val_0
+0 val_0
+0 val_0
+4 val_2
+8 val_4
+10 val_5
+10 val_5
+10 val_5
+16 val_8
+18 val_9
+20 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a b/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67
new file mode 100644
index 0000000000000..d6283e34d8ffc
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67
@@ -0,0 +1,14 @@
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+2 val_2
+4 val_4
+5 val_5
+5 val_5
+5 val_5
+8 val_8
+9 val_9
+10 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 b/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5
new file mode 100644
index 0000000000000..080180f9d0f0e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5
@@ -0,0 +1,14 @@
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 b/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99
new file mode 100644
index 0000000000000..4a64d5c625790
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99
@@ -0,0 +1,26 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d b/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0
new file mode 100644
index 0000000000000..1420c786fb228
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0
@@ -0,0 +1,29 @@
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba b/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8
new file mode 100644
index 0000000000000..1420c786fb228
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8
@@ -0,0 +1,29 @@
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 b/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 b/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092
new file mode 100644
index 0000000000000..aef9483bb0bc9
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092
@@ -0,0 +1,29 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
+16
+18
+20
diff --git a/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 b/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81
new file mode 100644
index 0000000000000..0bc413ef2e09e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81
@@ -0,0 +1,31 @@
+NULL
+NULL
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea b/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b
new file mode 100644
index 0000000000000..3131e64446f66
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b
@@ -0,0 +1,42 @@
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+2
+4
+4
+5
+5
+5
+8
+8
+9
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 b/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c
new file mode 100644
index 0000000000000..ff30bedb81861
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c
@@ -0,0 +1,35 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
+16
+18
+20
diff --git a/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 b/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c b/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d
new file mode 100644
index 0000000000000..60f6eacee9b14
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d
@@ -0,0 +1,22 @@
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+2 val_2
+4 val_2
+4 val_4
+5 val_5
+5 val_5
+5 val_5
+8 val_4
+8 val_8
+9 val_9
+10 val_10
+10 val_5
+10 val_5
+10 val_5
+16 val_8
+18 val_9
+20 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 b/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b b/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 b/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2
new file mode 100644
index 0000000000000..5baaac9bebf6d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2
@@ -0,0 +1,6 @@
+0 val_0
+0 val_0
+0 val_0
+4 val_4
+8 val_8
+10 val_10
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index f9a162ef4e3c0..3132d0112c708 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.hive
import org.apache.spark.sql.execution.SparkLogicalPlan
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.hive.execution.HiveComparisonTest
import org.apache.spark.sql.hive.test.TestHive
@@ -34,7 +34,7 @@ class CachedTableSuite extends HiveComparisonTest {
test("check that table is cached and uncache") {
TestHive.table("src").queryExecution.analyzed match {
- case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
+ case _ : InMemoryRelation => // Found evidence of caching
case noCache => fail(s"No cache node found in plan $noCache")
}
TestHive.uncacheTable("src")
@@ -45,7 +45,7 @@ class CachedTableSuite extends HiveComparisonTest {
test("make sure table is uncached") {
TestHive.table("src").queryExecution.analyzed match {
- case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
+ case cachePlan: InMemoryRelation =>
fail(s"Table still cached after uncache: $cachePlan")
case noCache => // Table uncached successfully
}
@@ -56,4 +56,20 @@ class CachedTableSuite extends HiveComparisonTest {
TestHive.uncacheTable("src")
}
}
+
+ test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") {
+ TestHive.hql("CACHE TABLE src")
+ TestHive.table("src").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => // Found evidence of caching
+ case _ => fail(s"Table 'src' should be cached")
+ }
+ assert(TestHive.isCached("src"), "Table 'src' should be cached")
+
+ TestHive.hql("UNCACHE TABLE src")
+ TestHive.table("src").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached")
+ case _ => // Found evidence of uncaching
+ }
+ assert(!TestHive.isCached("src"), "Table 'src' should not be cached")
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 357c7e654bd20..24c929ff7430d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -24,6 +24,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
import org.apache.spark.sql.Logging
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.hive.test.TestHive
@@ -141,7 +142,7 @@ abstract class HiveComparisonTest
// Hack: Hive simply prints the result of a SET command to screen,
// and does not return it as a query answer.
case _: SetCommand => Seq("0")
- case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
+ case _: LogicalNativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
case _: ExplainCommand => answer
case plan => if (isSorted(plan)) answer else answer.sorted
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index fb8f272d5abfe..ee194dbcb77b2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -172,7 +172,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"case_sensitivity",
// Flaky test, Hive sometimes returns different set of 10 rows.
- "lateral_view_outer"
+ "lateral_view_outer",
+
+ // After we stopped taking the `stringOrError` route, exceptions are thrown from these cases.
+ // See SPARK-2129 for details.
+ "join_view",
+ "mergejoins_mixed"
)
/**
@@ -476,7 +481,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"join_reorder3",
"join_reorder4",
"join_star",
- "join_view",
"lateral_view",
"lateral_view_cp",
"lateral_view_ppd",
@@ -507,7 +511,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"merge1",
"merge2",
"mergejoins",
- "mergejoins_mixed",
"multigroupby_singlemr",
"multi_insert_gby",
"multi_insert_gby3",
@@ -597,6 +600,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"select_unquote_and",
"select_unquote_not",
"select_unquote_or",
+ "semijoin",
"serde_regex",
"serde_reported_schema",
"set_variable_sub",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6c239b02ed09a..0d656c556965d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -17,9 +17,11 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.hive.test.TestHive._
+import scala.util.Try
+
import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.{SchemaRDD, execution, Row}
/**
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
@@ -162,16 +164,60 @@ class HiveQuerySuite extends HiveComparisonTest {
hql("SELECT * FROM src").toString
}
+ private val explainCommandClassName =
+ classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$")
+
+ def isExplanation(result: SchemaRDD) = {
+ val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
+ explanation.size == 1 && explanation.head.startsWith(explainCommandClassName)
+ }
+
test("SPARK-1704: Explain commands as a SchemaRDD") {
hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+
val rdd = hql("explain select key, count(value) from src group by key")
- assert(rdd.collect().size == 1)
- assert(rdd.toString.contains("ExplainCommand"))
- assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0,
- "actual contents of the result should be the plans of the query to be explained")
+ assert(isExplanation(rdd))
+
TestHive.reset()
}
+ test("Query Hive native command execution result") {
+ val tableName = "test_native_commands"
+
+ val q0 = hql(s"DROP TABLE IF EXISTS $tableName")
+ assert(q0.count() == 0)
+
+ val q1 = hql(s"CREATE TABLE $tableName(key INT, value STRING)")
+ assert(q1.count() == 0)
+
+ val q2 = hql("SHOW TABLES")
+ val tables = q2.select('result).collect().map { case Row(table: String) => table }
+ assert(tables.contains(tableName))
+
+ val q3 = hql(s"DESCRIBE $tableName")
+ assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) {
+ q3.select('result).collect().map { case Row(fieldDesc: String) =>
+ fieldDesc.split("\t").map(_.trim)
+ }
+ }
+
+ val q4 = hql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key")
+ assert(isExplanation(q4))
+
+ TestHive.reset()
+ }
+
+ test("Exactly once semantics for DDL and command statements") {
+ val tableName = "test_exactly_once"
+ val q0 = hql(s"CREATE TABLE $tableName(key INT, value STRING)")
+
+ // If the table was not created, the following assertion would fail
+ assert(Try(table(tableName)).isSuccess)
+
+ // If the CREATE TABLE command got executed again, the following assertion would fail
+ assert(Try(q0.count()).isSuccess)
+ }
+
test("parse HQL set commands") {
// Adapted from its SQL counterpart.
val testKey = "spark.sql.key.usedfortestonly"
@@ -195,52 +241,69 @@ class HiveQuerySuite extends HiveComparisonTest {
test("SET commands semantics for a HiveContext") {
// Adapted from its SQL counterpart.
val testKey = "spark.sql.key.usedfortestonly"
- var testVal = "test.val.0"
+ val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
- def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0))
+ def rowsToPairs(rows: Array[Row]) = rows.map { case Row(key: String, value: String) =>
+ key -> value
+ }
clear()
// "set" itself returns all config variables currently specified in SQLConf.
- assert(hql("set").collect().size == 0)
+ assert(hql("SET").collect().size == 0)
+
+ assertResult(Array(testKey -> testVal)) {
+ rowsToPairs(hql(s"SET $testKey=$testVal").collect())
+ }
- // "set key=val"
- hql(s"SET $testKey=$testVal")
- assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal"))
assert(hiveconf.get(testKey, "") == testVal)
+ assertResult(Array(testKey -> testVal)) {
+ rowsToPairs(hql("SET").collect())
+ }
hql(s"SET ${testKey + testKey}=${testVal + testVal}")
- assert(fromRows(hql("SET").collect()) sameElements
- Array(
- s"$testKey=$testVal",
- s"${testKey + testKey}=${testVal + testVal}"))
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
+ assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ rowsToPairs(hql("SET").collect())
+ }
// "set key"
- assert(fromRows(hql(s"SET $testKey").collect()) sameElements
- Array(s"$testKey=$testVal"))
- assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements
- Array(s"$nonexistentKey is undefined"))
+ assertResult(Array(testKey -> testVal)) {
+ rowsToPairs(hql(s"SET $testKey").collect())
+ }
+
+ assertResult(Array(nonexistentKey -> "")) {
+ rowsToPairs(hql(s"SET $nonexistentKey").collect())
+ }
// Assert that sql() should have the same effects as hql() by repeating the above using sql().
clear()
- assert(sql("set").collect().size == 0)
+ assert(sql("SET").collect().size == 0)
+
+ assertResult(Array(testKey -> testVal)) {
+ rowsToPairs(sql(s"SET $testKey=$testVal").collect())
+ }
- sql(s"SET $testKey=$testVal")
- assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal"))
assert(hiveconf.get(testKey, "") == testVal)
+ assertResult(Array(testKey -> testVal)) {
+ rowsToPairs(sql("SET").collect())
+ }
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
- assert(fromRows(sql("SET").collect()) sameElements
- Array(
- s"$testKey=$testVal",
- s"${testKey + testKey}=${testVal + testVal}"))
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
+ assertResult(Array(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ rowsToPairs(sql("SET").collect())
+ }
- assert(fromRows(sql(s"SET $testKey").collect()) sameElements
- Array(s"$testKey=$testVal"))
- assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements
- Array(s"$nonexistentKey is undefined"))
+ assertResult(Array(testKey -> testVal)) {
+ rowsToPairs(sql(s"SET $testKey").collect())
+ }
+
+ assertResult(Array(nonexistentKey -> "")) {
+ rowsToPairs(sql(s"SET $nonexistentKey").collect())
+ }
+
+ clear()
}
// Put tests that depend on specific Hive settings before these last two test,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
index 86753360a07e4..a0aeacbc733bd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
@@ -27,6 +27,7 @@ private[streaming] class ContextWaiter {
}
def notifyStop() = synchronized {
+ stopped = true
notifyAll()
}
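The one-line fix above sets the stopped flag before waking waiters; otherwise a thread that starts waiting after notifyStop() has run would keep blocking even though the context is already stopped. Below is a minimal sketch of the guarded wait/notify idiom this relies on -- the class and method names are illustrative only, not the actual ContextWaiter API:

    // Illustrative sketch, not the actual ContextWaiter implementation.
    // The point of the fix: flip the flag *before* notifyAll(), and have the
    // waiting side loop on the flag, so late arrivals and spurious wakeups
    // cannot miss the stop signal.
    class SimpleStopWaiter {
      private var stopped = false

      def notifyStop(): Unit = synchronized {
        stopped = true   // without this, waiters re-check the flag and keep waiting
        notifyAll()
      }

      def waitForStop(): Unit = synchronized {
        while (!stopped) {
          wait()
        }
      }
    }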
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
index 303d149d285e1..d9ac3c91f6e36 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -29,7 +29,6 @@ import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import scala.language.postfixOps
/** Testsuite for testing the network receiver behavior */
class NetworkReceiverSuite extends FunSuite with Timeouts {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index cd86019f63e7e..7b33d3b235466 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -223,6 +223,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
}
}
+ test("awaitTermination after stop") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => x).register()
+
+ failAfter(10000 millis) {
+ ssc.start()
+ ssc.stop()
+ ssc.awaitTermination()
+ }
+ }
+
test("awaitTermination with error in task") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index ef0efa552ceaf..2861f5335ae36 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -27,12 +27,12 @@ import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler._
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.Logging
-class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
+class StreamingListenerSuite extends TestSuiteBase with Matchers {
val input = (1 to 4).map(Seq(_)).toSeq
val operation = (d: DStream[Int]) => d.map(x => x)
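For context on the ShouldMatchers to Matchers change in this and the other test suites touched by this patch: org.scalatest.matchers.ShouldMatchers was deprecated in favor of org.scalatest.Matchers, and the `should` DSL itself is unchanged, so only the import and the mixed-in trait move. A minimal sketch (hypothetical suite name):

    import org.scalatest.{FunSuite, Matchers}

    class ExampleMatchersSuite extends FunSuite with Matchers {
      test("the should DSL works unchanged after the trait rename") {
        Seq(1, 2, 3) should have length 3
        "spark" should startWith ("sp")
      }
    }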
diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
index 6a261e19a35cd..03a73f92b275e 100644
--- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
@@ -40,74 +40,78 @@ object GenerateMIMAIgnore {
private val classLoader = Thread.currentThread().getContextClassLoader
private val mirror = runtimeMirror(classLoader)
- private def classesPrivateWithin(packageName: String): Set[String] = {
+
+ private def isDeveloperApi(sym: unv.Symbol) =
+ sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi])
+
+ private def isExperimental(sym: unv.Symbol) =
+ sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.Experimental])
+
+
+ private def isPackagePrivate(sym: unv.Symbol) =
+ !sym.privateWithin.fullName.startsWith("<none>")
+
+ private def isPackagePrivateModule(moduleSymbol: unv.ModuleSymbol) =
+ !moduleSymbol.privateWithin.fullName.startsWith("<none>")
+
+ /**
+ * For every class, checks via Scala reflection whether the class itself or any of its members
+ * carries the DeveloperApi or Experimental annotation, or is package private, and
+ * returns the tuple of such classes and members.
+ */
+ private def privateWithin(packageName: String): (Set[String], Set[String]) = {
val classes = getClasses(packageName)
val ignoredClasses = mutable.HashSet[String]()
+ val ignoredMembers = mutable.HashSet[String]()
- def isPackagePrivate(className: String) = {
+ for (className <- classes) {
try {
- /* Couldn't figure out if it's possible to determine a-priori whether a given symbol
- is a module or class. */
-
- val privateAsClass = mirror
- .classSymbol(Class.forName(className, false, classLoader))
- .privateWithin
- .fullName
- .startsWith(packageName)
-
- val privateAsModule = mirror
- .staticModule(className)
- .privateWithin
- .fullName
- .startsWith(packageName)
-
- privateAsClass || privateAsModule
- } catch {
- case _: Throwable => {
- println("Error determining visibility: " + className)
- false
+ val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader))
+ val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary.
+ val directlyPrivateSpark =
+ isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol)
+ val developerApi = isDeveloperApi(classSymbol)
+ val experimental = isExperimental(classSymbol)
+
+ /* Inner classes defined within a private[spark] class or object are effectively
+ invisible, so we account for them as package private. */
+ lazy val indirectlyPrivateSpark = {
+ val maybeOuter = className.toString.takeWhile(_ != '$')
+ if (maybeOuter != className) {
+ isPackagePrivate(mirror.classSymbol(Class.forName(maybeOuter, false, classLoader))) ||
+ isPackagePrivateModule(mirror.staticModule(maybeOuter))
+ } else {
+ false
+ }
+ }
+ if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) {
+ ignoredClasses += className
+ } else {
+ // check if this class has package-private/annotated members.
+ ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol)
}
- }
- }
- def isDeveloperApi(className: String) = {
- try {
- val clazz = mirror.classSymbol(Class.forName(className, false, classLoader))
- clazz.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi])
} catch {
- case _: Throwable => {
- println("Error determining Annotations: " + className)
- false
- }
+ case _: Throwable => println("Error instrumenting class:" + className)
}
}
+ (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet)
+ }
- for (className <- classes) {
- val directlyPrivateSpark = isPackagePrivate(className)
- val developerApi = isDeveloperApi(className)
-
- /* Inner classes defined within a private[spark] class or object are effectively
- invisible, so we account for them as package private. */
- val indirectlyPrivateSpark = {
- val maybeOuter = className.toString.takeWhile(_ != '$')
- if (maybeOuter != className) {
- isPackagePrivate(maybeOuter)
- } else {
- false
- }
- }
- if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi) {
- ignoredClasses += className
- }
- }
- ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet
+ private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = {
+ classSymbol.typeSignature.members
+ .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName)
}
def main(args: Array[String]) {
- scala.tools.nsc.io.File(".generated-mima-excludes").
- writeAll(classesPrivateWithin("org.apache.spark").mkString("\n"))
- println("Created : .generated-mima-excludes in current directory.")
+ val (privateClasses, privateMembers) = privateWithin("org.apache.spark")
+ scala.tools.nsc.io.File(".generated-mima-class-excludes").
+ writeAll(privateClasses.mkString("\n"))
+ println("Created : .generated-mima-class-excludes in current directory.")
+ scala.tools.nsc.io.File(".generated-mima-member-excludes").
+ writeAll(privateMembers.mkString("\n"))
+ println("Created : .generated-mima-member-excludes in current directory.")
}
@@ -140,10 +144,17 @@ object GenerateMIMAIgnore {
* Get all classes in a package from a jar file.
*/
private def getClassesFromJar(jarPath: String, packageName: String) = {
+ import scala.collection.mutable
val jar = new JarFile(new File(jarPath))
val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName))
- val classes = for (entry <- enums if entry.endsWith(".class"))
- yield Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader)
+ val classes = mutable.HashSet[Class[_]]()
+ for (entry <- enums if entry.endsWith(".class")) {
+ try {
+ classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader)
+ } catch {
+ case _: Throwable => println("Unable to load:" + entry)
+ }
+ }
classes
}
}
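A hedged sketch of the reflection check that the isDeveloperApi/isExperimental helpers above rely on, assuming the Scala 2.10 runtime-reflection API this file already uses (Annotation.tpe was later deprecated in favor of Annotation.tree.tpe); the generic annotation parameter is only an example:

    import scala.reflect.runtime.{universe => ru}

    // True if `sym` carries an annotation of type A,
    // e.g. org.apache.spark.annotation.DeveloperApi.
    def hasAnnotation[A: ru.TypeTag](sym: ru.Symbol): Boolean =
      sym.annotations.exists(_.tpe =:= ru.typeOf[A])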
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 8f0ecb855718e..1cc9c33cd2d02 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -277,7 +277,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
yarnAllocator.allocateContainers(
math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0))
ApplicationMaster.incrementAllocatorLoop(1)
- Thread.sleep(100)
+ Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL)
}
} finally {
// In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
@@ -416,6 +416,7 @@ object ApplicationMaster {
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ private val ALLOCATE_HEARTBEAT_INTERVAL = 100
def incrementAllocatorLoop(by: Int) {
val count = yarnAllocatorLoop.getAndAdd(by)
@@ -467,13 +468,22 @@ object ApplicationMaster {
})
}
- // Wait for initialization to complete and atleast 'some' nodes can get allocated.
+ modified
+ }
+
+
+ /**
+ * Returns when we've either
+ * 1) received all the requested executors,
+ * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms,
+ * 3) hit an error that causes us to terminate trying to get containers.
+ */
+ def waitForInitialAllocations() {
yarnAllocatorLoop.synchronized {
while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
yarnAllocatorLoop.wait(1000L)
}
}
- modified
}
def main(argStrings: Array[String]) {
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 8226207de42b8..4ccddc214c8ad 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -85,7 +85,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def run() {
val appId = runApp()
monitorApplication(appId)
- System.exit(0)
}
def logClusterResourceDetails() {
@@ -179,8 +178,17 @@ object Client {
System.setProperty("SPARK_YARN_MODE", "true")
val sparkConf = new SparkConf
- val args = new ClientArguments(argStrings, sparkConf)
- new Client(args, sparkConf).run
+ try {
+ val args = new ClientArguments(argStrings, sparkConf)
+ new Client(args, sparkConf).run()
+ } catch {
+ case e: Exception => {
+ Console.err.println(e.getMessage)
+ System.exit(1)
+ }
+ }
+
+ System.exit(0)
}
}
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index a3bd91590fc25..b6ecae1e652fe 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -271,6 +271,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
.asInstanceOf[FinishApplicationMasterRequest]
finishReq.setAppAttemptId(appAttemptId)
finishReq.setFinishApplicationStatus(status)
+ finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", ""))
resourceManager.finishApplicationMaster(finishReq)
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index b2c413b6d267c..fd3ef9e1fa2de 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -125,11 +125,11 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) {
case Nil =>
if (userClass == null) {
- printUsageAndExit(1)
+ throw new IllegalArgumentException(getUsageMessage())
}
case _ =>
- printUsageAndExit(1, args)
+ throw new IllegalArgumentException(getUsageMessage(args))
}
}
@@ -138,11 +138,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) {
}
- def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
- if (unknownParam != null) {
- System.err.println("Unknown/unsupported param " + unknownParam)
- }
- System.err.println(
+ def getUsageMessage(unknownParam: Any = null): String = {
+ val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else ""
+
+ message +
"Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
"Options:\n" +
" --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" +
@@ -158,8 +157,5 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) {
" --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
" --files files Comma separated list of files to be distributed with the job.\n" +
" --archives archives Comma separated list of archives to be distributed with the job."
- )
- System.exit(exitCode)
}
-
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 801e8b381588f..6861b503000ca 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn
import java.io.File
import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException}
-import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, ListBuffer, Map}
@@ -37,8 +36,8 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.util.{Apps, Records}
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.hadoop.yarn.util.Records
+import org.apache.spark.{SparkException, Logging, SparkConf, SparkContext}
/**
* The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The
@@ -80,7 +79,7 @@ trait ClientBase extends Logging {
).foreach { case(cond, errStr) =>
if (cond) {
logError(errStr)
- args.printUsageAndExit(1)
+ throw new IllegalArgumentException(args.getUsageMessage())
}
}
}
@@ -95,15 +94,20 @@ trait ClientBase extends Logging {
// If we have requested more then the clusters max for a single resource then exit.
if (args.executorMemory > maxMem) {
- logError("Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster.".
- format(args.executorMemory, maxMem))
- System.exit(1)
+ val errorMessage =
+ "Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster."
+ .format(args.executorMemory, maxMem)
+
+ logError(errorMessage)
+ throw new IllegalArgumentException(errorMessage)
}
val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
if (amMem > maxMem) {
- logError("Required AM memory (%d) is above the max threshold (%d) of this cluster".
- format(args.amMemory, maxMem))
- System.exit(1)
+
+ val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster."
+ .format(args.amMemory, maxMem)
+ logError(errorMessage)
+ throw new IllegalArgumentException(errorMessage)
}
// We could add checks to make sure the entire cluster has enough resources but that involves
@@ -169,14 +173,13 @@ trait ClientBase extends Logging {
destPath
}
- def qualifyForLocal(localURI: URI): Path = {
+ private def qualifyForLocal(localURI: URI): Path = {
var qualifiedURI = localURI
- // If not specified assume these are in the local filesystem to keep behavior like Hadoop
+ // If not specified, assume these are in the local filesystem to keep behavior like Hadoop
if (qualifiedURI.getScheme() == null) {
qualifiedURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(qualifiedURI)).toString)
}
- val qualPath = new Path(qualifiedURI)
- qualPath
+ new Path(qualifiedURI)
}
def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
@@ -188,8 +191,9 @@ trait ClientBase extends Logging {
val delegTokenRenewer = Master.getMasterPrincipal(conf)
if (UserGroupInformation.isSecurityEnabled()) {
if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
- logError("Can't get Master Kerberos principal for use as renewer")
- System.exit(1)
+ val errorMessage = "Can't get Master Kerberos principal for use as renewer"
+ logError(errorMessage)
+ throw new SparkException(errorMessage)
}
}
val dst = new Path(fs.getHomeDirectory(), appStagingDir)
@@ -305,13 +309,13 @@ trait ClientBase extends Logging {
val amMemory = calculateAMMemory(newApp)
- val JAVA_OPTS = ListBuffer[String]()
+ val javaOpts = ListBuffer[String]()
// Add Xmx for AM memory
- JAVA_OPTS += "-Xmx" + amMemory + "m"
+ javaOpts += "-Xmx" + amMemory + "m"
val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
- JAVA_OPTS += "-Djava.io.tmpdir=" + tmpDir
+ javaOpts += "-Djava.io.tmpdir=" + tmpDir
// TODO: Remove once cpuset version is pushed out.
// The context is, default gc for server class machines ends up using all cores to do gc -
@@ -325,11 +329,11 @@ trait ClientBase extends Logging {
if (useConcurrentAndIncrementalGC) {
// In our expts, using (default) throughput collector has severe perf ramifications in
// multi-tenant machines
- JAVA_OPTS += "-XX:+UseConcMarkSweepGC"
- JAVA_OPTS += "-XX:+CMSIncrementalMode"
- JAVA_OPTS += "-XX:+CMSIncrementalPacing"
- JAVA_OPTS += "-XX:CMSIncrementalDutyCycleMin=0"
- JAVA_OPTS += "-XX:CMSIncrementalDutyCycle=10"
+ javaOpts += "-XX:+UseConcMarkSweepGC"
+ javaOpts += "-XX:+CMSIncrementalMode"
+ javaOpts += "-XX:+CMSIncrementalPacing"
+ javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
+ javaOpts += "-XX:CMSIncrementalDutyCycle=10"
}
// SPARK_JAVA_OPTS is deprecated, but for backwards compatibility:
@@ -344,22 +348,22 @@ trait ClientBase extends Logging {
// If we are being launched in client mode, forward the spark-conf options
// onto the executor launcher
for ((k, v) <- sparkConf.getAll) {
- JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\""
+ javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\""
}
} else {
// If we are being launched in standalone mode, capture and forward any spark
// system properties (e.g. set by spark-class).
for ((k, v) <- sys.props.filterKeys(_.startsWith("spark"))) {
- JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\""
+ javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\""
}
- sys.props.get("spark.driver.extraJavaOptions").foreach(opts => JAVA_OPTS += opts)
- sys.props.get("spark.driver.libraryPath").foreach(p => JAVA_OPTS += s"-Djava.library.path=$p")
+ sys.props.get("spark.driver.extraJavaOptions").foreach(opts => javaOpts += opts)
+ sys.props.get("spark.driver.libraryPath").foreach(p => javaOpts += s"-Djava.library.path=$p")
}
- JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources)
+ javaOpts += ClientBase.getLog4jConfiguration(localResources)
// Command for the ApplicationMaster
val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++
- JAVA_OPTS ++
+ javaOpts ++
Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar,
userArgsToString(args),
"--executor-memory", args.executorMemory.toString,
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
index 32f8861dc9503..43dbb2464f929 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import org.apache.spark.{Logging, SparkConf}
@@ -46,19 +46,19 @@ trait ExecutorRunnableUtil extends Logging {
executorCores: Int,
localResources: HashMap[String, LocalResource]): List[String] = {
// Extra options for the JVM
- val JAVA_OPTS = ListBuffer[String]()
+ val javaOpts = ListBuffer[String]()
// Set the JVM memory
val executorMemoryString = executorMemory + "m"
- JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " "
+ javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " "
// Set extra Java options for the executor, if defined
sys.props.get("spark.executor.extraJavaOptions").foreach { opts =>
- JAVA_OPTS += opts
+ javaOpts += opts
}
- JAVA_OPTS += "-Djava.io.tmpdir=" +
+ javaOpts += "-Djava.io.tmpdir=" +
new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
- JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources)
+ javaOpts += ClientBase.getLog4jConfiguration(localResources)
// Certain configs need to be passed here because they are needed before the Executor
// registers with the Scheduler and transfers the spark configs. Since the Executor backend
@@ -66,10 +66,10 @@ trait ExecutorRunnableUtil extends Logging {
// authentication settings.
sparkConf.getAll.
filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }.
- foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" }
+ foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" }
sparkConf.getAkkaConf.
- foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" }
+ foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" }
// Commenting it out for now - so that people can refer to the properties if required. Remove
// it once cpuset version is pushed out.
@@ -88,11 +88,11 @@ trait ExecutorRunnableUtil extends Logging {
// multi-tennent machines
// The options are based on
// http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
- JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
- JAVA_OPTS += " -XX:+CMSIncrementalMode "
- JAVA_OPTS += " -XX:+CMSIncrementalPacing "
- JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
- JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ javaOpts += " -XX:+UseConcMarkSweepGC "
+ javaOpts += " -XX:+CMSIncrementalMode "
+ javaOpts += " -XX:+CMSIncrementalPacing "
+ javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 "
+ javaOpts += " -XX:CMSIncrementalDutyCycle=10 "
}
*/
@@ -104,7 +104,7 @@ trait ExecutorRunnableUtil extends Logging {
// TODO: If the OOM is not recoverable by rescheduling it on different node, then do
// 'something' to fail job ... akin to blacklisting trackers in mapred ?
"-XX:OnOutOfMemoryError='kill %p'") ++
- JAVA_OPTS ++
+ javaOpts ++
Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend",
masterAddress.toString,
slaveId.toString,
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index a4638cc863611..39cdd2e8a522b 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -33,10 +33,11 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
def this(sc: SparkContext) = this(sc, new Configuration())
- // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate
- // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?)
- // Subsequent creations are ignored - since nodes are already allocated by then.
-
+ // Nothing else for now ... initialize application master : which needs a SparkContext to
+ // determine how to allocate.
+ // Note that only the first creation of a SparkContext influences (and ideally, there must be
+ // only one SparkContext, right ?). Subsequent creations are ignored since executors are already
+ // allocated by then.
// By default, rack is unknown
override def getRackForHost(hostPort: String): Option[String] = {
@@ -48,6 +49,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
if (sparkContextInitialized){
+ ApplicationMaster.waitForInitialAllocations()
// Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
Thread.sleep(3000L)
}
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 33a60d978c586..6244332f23737 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -19,13 +19,12 @@ package org.apache.spark.deploy.yarn
import java.io.IOException
import java.util.concurrent.CopyOnWriteArrayList
-import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.protocolrecords._
@@ -33,8 +32,7 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.ipc.YarnRPC
-import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.hadoop.yarn.webapp.util.WebAppUtils
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
@@ -77,17 +75,18 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// than user specified and /tmp.
System.setProperty("spark.local.dir", getLocalDirs())
- // set the web ui port to be ephemeral for yarn so we don't conflict with
+ // Set the web ui port to be ephemeral for yarn so we don't conflict with
// other spark processes running on the same box
System.setProperty("spark.ui.port", "0")
- // when running the AM, the Spark master is always "yarn-cluster"
+ // When running the AM, the Spark master is always "yarn-cluster"
System.setProperty("spark.master", "yarn-cluster")
- // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using.
+ // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using.
ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
- appAttemptId = getApplicationAttemptId()
+ appAttemptId = ApplicationMaster.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
amClient = AMRMClient.createAMRMClient()
amClient.init(yarnConf)
@@ -99,7 +98,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
ApplicationMaster.register(this)
// Call this to force generation of secret so it gets populated into the
- // hadoop UGI. This has to happen before the startUserClass which does a
+ // Hadoop UGI. This has to happen before the startUserClass which does a
// doAs in order for the credentials to be passed on to the executor containers.
val securityMgr = new SecurityManager(sparkConf)
@@ -121,7 +120,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// Allocate all containers
allocateExecutors()
- // Wait for the user class to Finish
+ // Launch thread that will heartbeat to the RM so it won't think the app has died.
+ launchReporterThread()
+
+ // Wait for the user class to finish
userThread.join()
System.exit(0)
@@ -141,7 +143,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
"spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params)
}
- /** Get the Yarn approved local directories. */
+ // Get the Yarn approved local directories.
private def getLocalDirs(): String = {
// Hadoop 0.23 and 2.x have different Environment variable names for the
// local dirs, so lets check both. We assume one of the 2 is set.
@@ -150,18 +152,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
.orElse(Option(System.getenv("LOCAL_DIRS")))
localDirs match {
- case None => throw new Exception("Yarn Local dirs can't be empty")
+ case None => throw new Exception("Yarn local dirs can't be empty")
case Some(l) => l
}
- }
-
- private def getApplicationAttemptId(): ApplicationAttemptId = {
- val envs = System.getenv()
- val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
- val containerId = ConverterUtils.toContainerId(containerIdString)
- val appAttemptId = containerId.getApplicationAttemptId()
- logInfo("ApplicationAttemptId: " + appAttemptId)
- appAttemptId
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
@@ -173,25 +166,23 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
logInfo("Starting the user JAR in a separate Thread")
val mainMethod = Class.forName(
args.userClass,
- false /* initialize */ ,
+ false,
Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
val t = new Thread {
override def run() {
-
- var successed = false
+ var succeeded = false
try {
// Copy
- var mainArgs: Array[String] = new Array[String](args.userArgs.size)
+ val mainArgs = new Array[String](args.userArgs.size)
args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
mainMethod.invoke(null, mainArgs)
- // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR
- // userThread will stop here unless it has uncaught exception thrown out
- // It need shutdown hook to set SUCCEEDED
- successed = true
+ // Some apps have "System.exit(0)" at the end. The user thread will stop here unless
+ // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED.
+ succeeded = true
} finally {
- logDebug("finishing main")
+ logDebug("Finishing main")
isLastAMRetry = true
- if (successed) {
+ if (succeeded) {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
} else {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED)
@@ -199,11 +190,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
}
+ t.setName("Driver")
t.start()
t
}
- // This need to happen before allocateExecutors()
+ // This needs to happen before allocateExecutors()
private def waitForSparkContextInitialized() {
logInfo("Waiting for Spark context initialization")
try {
@@ -231,7 +223,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
sparkContext.preferredNodeLocationData,
sparkContext.getConf)
} else {
- logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d".
+ logWarning("Unable to retrieve SparkContext in spite of waiting for %d, maxNumTries = %d".
format(numTries * waitTime, maxNumTries))
this.yarnAllocator = YarnAllocationHandler.newAllocator(
yarnConf,
@@ -242,48 +234,37 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
} finally {
- // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT :
- // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks.
- ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ // In case of exceptions, etc - ensure that the loop in
+ // ApplicationMaster#sparkContextInitialized() breaks.
+ ApplicationMaster.doneWithInitialAllocations()
}
}
private def allocateExecutors() {
try {
- logInfo("Allocating " + args.numExecutors + " executors.")
- // Wait until all containers have finished
+ logInfo("Requesting" + args.numExecutors + " executors.")
+ // Wait until all containers have launched
yarnAllocator.addResourceRequests(args.numExecutors)
yarnAllocator.allocateResources()
// Exits the loop if the user thread exits.
+
+ var iters = 0
while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) {
checkNumExecutorsFailed()
allocateMissingExecutor()
yarnAllocator.allocateResources()
- ApplicationMaster.incrementAllocatorLoop(1)
- Thread.sleep(100)
+ if (iters == ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) {
+ ApplicationMaster.doneWithInitialAllocations()
+ }
+ Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL)
+ iters += 1
}
} finally {
- // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
- // so that the loop in ApplicationMaster#sparkContextInitialized() breaks.
- ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ // In case of exceptions, etc - ensure that the loop in
+ // ApplicationMaster#sparkContextInitialized() breaks.
+ ApplicationMaster.doneWithInitialAllocations()
}
logInfo("All executors have launched.")
-
- // Launch a progress reporter thread, else the app will get killed after expiration
- // (def: 10mins) timeout.
- if (userThread.isAlive) {
- // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
- val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
-
- // we want to be reasonably responsive without causing too many requests to RM.
- val schedulerInterval =
- sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
-
- // must be <= timeoutInterval / 2.
- val interval = math.min(timeoutInterval / 2, schedulerInterval)
-
- launchReporterThread(interval)
- }
}
private def allocateMissingExecutor() {
@@ -303,47 +284,35 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
- private def launchReporterThread(_sleepTime: Long): Thread = {
- val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
+ private def launchReporterThread(): Thread = {
+ // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+ val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
+
+ // must be <= expiryInterval / 2.
+ val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval))
val t = new Thread {
override def run() {
while (userThread.isAlive) {
checkNumExecutorsFailed()
allocateMissingExecutor()
- sendProgress()
- Thread.sleep(sleepTime)
+ logDebug("Sending progress")
+ yarnAllocator.allocateResources()
+ Thread.sleep(interval)
}
}
}
// Setting to daemon status, though this is usually not a good idea.
t.setDaemon(true)
t.start()
- logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ logInfo("Started progress reporter thread - heartbeat interval : " + interval)
t
}
- private def sendProgress() {
- logDebug("Sending progress")
- // Simulated with an allocate request with no nodes requested.
- yarnAllocator.allocateResources()
- }
-
- /*
- def printContainers(containers: List[Container]) = {
- for (container <- containers) {
- logInfo("Launching shell command on a new container."
- + ", containerId=" + container.getId()
- + ", containerNode=" + container.getNodeId().getHost()
- + ":" + container.getNodeId().getPort()
- + ", containerNodeURI=" + container.getNodeHttpAddress()
- + ", containerState" + container.getState()
- + ", containerResourceMemory"
- + container.getResource().getMemory())
- }
- }
- */
-
def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") {
synchronized {
if (isFinished) {
@@ -351,7 +320,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
isFinished = true
- logInfo("finishApplicationMaster with " + status)
+ logInfo("Unregistering ApplicationMaster with " + status)
if (registered) {
val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl)
@@ -386,7 +355,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
def run() {
logInfo("AppMaster received a signal.")
- // we need to clean up staging dir before HDFS is shut down
+ // We need to clean up staging dir before HDFS is shut down
// make sure we don't delete it until this is the last AM
if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()
}
@@ -401,21 +370,24 @@ object ApplicationMaster {
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ private val ALLOCATE_HEARTBEAT_INTERVAL = 100
private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
val sparkContextRef: AtomicReference[SparkContext] =
- new AtomicReference[SparkContext](null /* initialValue */)
+ new AtomicReference[SparkContext](null)
- val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+ // Variable used to notify the YarnClusterScheduler that it should stop waiting
+ // for the initial set of executors to be started and get on with its business.
+ val doneWithInitialAllocationsMonitor = new Object()
- def incrementAllocatorLoop(by: Int) {
- val count = yarnAllocatorLoop.getAndAdd(by)
- if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
- yarnAllocatorLoop.synchronized {
- // to wake threads off wait ...
- yarnAllocatorLoop.notifyAll()
- }
+ @volatile var isDoneWithInitialAllocations = false
+
+ def doneWithInitialAllocations() {
+ isDoneWithInitialAllocations = true
+ doneWithInitialAllocationsMonitor.synchronized {
+ // to wake threads off wait ...
+ doneWithInitialAllocationsMonitor.notifyAll()
}
}
@@ -423,7 +395,10 @@ object ApplicationMaster {
applicationMasters.add(master)
}
- // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm...
+ /**
+ * Called from YarnClusterScheduler to notify the AM code that a SparkContext has been
+ * initialized in the user code.
+ */
def sparkContextInitialized(sc: SparkContext): Boolean = {
var modified = false
sparkContextRef.synchronized {
@@ -431,7 +406,7 @@ object ApplicationMaster {
sparkContextRef.notifyAll()
}
- // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do
+ // Add a shutdown hook - as a best effort in case users do not call sc.stop or do
// System.exit.
// Should not really have to do this, but it helps YARN to evict resources earlier.
// Not to mention, prevent the Client from declaring failure even though we exited properly.
@@ -454,13 +429,29 @@ object ApplicationMaster {
})
}
- // Wait for initialization to complete and atleast 'some' nodes can get allocated.
- yarnAllocatorLoop.synchronized {
- while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
- yarnAllocatorLoop.wait(1000L)
+ // Wait for initialization to complete and at least 'some' nodes to get allocated.
+ modified
+ }
+
+ /**
+ * Returns when we've either
+ * 1) received all the requested executors,
+ * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms,
+ * 3) hit an error that causes us to terminate trying to get containers.
+ */
+ def waitForInitialAllocations() {
+ doneWithInitialAllocationsMonitor.synchronized {
+ while (!isDoneWithInitialAllocations) {
+ doneWithInitialAllocationsMonitor.wait(1000L)
}
}
- modified
+ }
+
+ def getApplicationAttemptId(): ApplicationAttemptId = {
+ val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ appAttemptId
}
def main(argStrings: Array[String]) {
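The AtomicInteger counter (yarnAllocatorLoop) is replaced above by a volatile flag plus a plain monitor. A minimal sketch of that handshake, with illustrative names rather than the actual ApplicationMaster members:

    object AllocationGate {
      private val monitor = new Object()
      @volatile private var done = false

      // Producer side: flip the flag, then wake anyone blocked in awaitDone().
      def markDone(): Unit = {
        done = true
        monitor.synchronized { monitor.notifyAll() }
      }

      // Consumer side: loop on the flag with a bounded wait, mirroring
      // waitForInitialAllocations() above.
      def awaitDone(): Unit = monitor.synchronized {
        while (!done) {
          monitor.wait(1000L)
        }
      }
    }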
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 393edd1f2d670..80a8bceb17269 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -21,14 +21,12 @@ import java.nio.ByteBuffer
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.DataOutputBuffer
-import org.apache.hadoop.yarn.api._
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
-import org.apache.hadoop.yarn.util.{Apps, Records}
+import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{Logging, SparkConf}
@@ -97,12 +95,11 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def run() {
val appId = runApp()
monitorApplication(appId)
- System.exit(0)
}
def logClusterResourceDetails() {
val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics
- logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " +
+ logInfo("Got Cluster metric info from ResourceManager, number of NodeManagers: " +
clusterMetrics.getNumNodeManagers)
val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue)
@@ -133,7 +130,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def submitApp(appContext: ApplicationSubmissionContext) = {
// Submit the application to the applications manager.
- logInfo("Submitting application to ASM")
+ logInfo("Submitting application to ResourceManager")
yarnClient.submitApplication(appContext)
}
@@ -149,7 +146,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
Thread.sleep(interval)
val report = yarnClient.getApplicationReport(appId)
- logInfo("Application report from ASM: \n" +
+ logInfo("Application report from ResourceManager: \n" +
"\t application identifier: " + appId.toString() + "\n" +
"\t appId: " + appId.getId() + "\n" +
"\t clientToAMToken: " + report.getClientToAMToken() + "\n" +
@@ -188,9 +185,18 @@ object Client {
// see Client#setupLaunchEnv().
System.setProperty("SPARK_YARN_MODE", "true")
val sparkConf = new SparkConf()
- val args = new ClientArguments(argStrings, sparkConf)
- new Client(args, sparkConf).run()
+ try {
+ val args = new ClientArguments(argStrings, sparkConf)
+ new Client(args, sparkConf).run()
+ } catch {
+ case e: Exception => {
+ Console.err.println(e.getMessage)
+ System.exit(1)
+ }
+ }
+
+ System.exit(0)
}
}
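
Moving System.exit out of run() and into main() makes the exit code explicit: a failure while parsing arguments or submitting the application prints the message and exits with status 1, and only a normal return reaches exit 0. A skeletal sketch of that control flow, where runClient stands in for the real argument parsing and submission and does not reproduce them:

```scala
object ClientMainSketch {
  // Stand-in for ClientArguments parsing plus Client.run(); any failure is
  // surfaced as an exception rather than an in-method System.exit.
  def runClient(args: Array[String]): Unit = {
    if (args.isEmpty) {
      throw new IllegalArgumentException("no application arguments given")
    }
    // ... build the application context, submit it, monitor until completion ...
  }

  def main(args: Array[String]): Unit = {
    try {
      runClient(args)
    } catch {
      case e: Exception =>
        Console.err.println(e.getMessage)
        System.exit(1)
    }
    // Reached only if runClient returned normally.
    System.exit(0)
  }
}
```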
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index d93e5bb0225d5..f71ad036ce0f2 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -72,8 +72,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
override def preStart() {
logInfo("Listen to driver: " + driverUrl)
driver = context.actorSelection(driverUrl)
- // Send a hello message thus the connection is actually established,
- // thus we can monitor Lifecycle Events.
+ // Send a hello message to establish the connection, after which
+ // we can monitor Lifecycle Events.
driver ! "Hello"
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
}
@@ -95,7 +95,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
amClient.init(yarnConf)
amClient.start()
- appAttemptId = getApplicationAttemptId()
+ appAttemptId = ApplicationMaster.getApplicationAttemptId()
registerApplicationMaster()
waitForSparkMaster()
@@ -115,7 +115,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
val interval = math.min(timeoutInterval / 2, schedulerInterval)
reporterThread = launchReporterThread(interval)
-
+
// Wait for the reporter thread to Finish.
reporterThread.join()
@@ -134,25 +134,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
.orElse(Option(System.getenv("LOCAL_DIRS")))
-
+
localDirs match {
case None => throw new Exception("Yarn Local dirs can't be empty")
case Some(l) => l
}
- }
-
- private def getApplicationAttemptId(): ApplicationAttemptId = {
- val envs = System.getenv()
- val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
- val containerId = ConverterUtils.toContainerId(containerIdString)
- val appAttemptId = containerId.getApplicationAttemptId()
- logInfo("ApplicationAttemptId: " + appAttemptId)
- appAttemptId
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
logInfo("Registering the ApplicationMaster")
- // TODO:(Raymond) Find out Spark UI address and fill in here?
+ // TODO: Find out the client's Spark UI address and fill it in here?
amClient.registerApplicationMaster(Utils.localHostName(), 0, "")
}
@@ -185,8 +176,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
private def allocateExecutors() {
-
- // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+ // TODO: should get preferredNodeLocationData from SparkContext; just fake an empty one for now.
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
scala.collection.immutable.Map()
@@ -198,8 +188,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
preferredNodeLocationData,
sparkConf)
- logInfo("Allocating " + args.numExecutors + " executors.")
- // Wait until all containers have finished
+ logInfo("Requesting " + args.numExecutors + " executors.")
+ // Wait until all containers have launched
yarnAllocator.addResourceRequests(args.numExecutors)
yarnAllocator.allocateResources()
while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) {
@@ -221,7 +211,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
}
}
- // TODO: We might want to extend this to allocate more containers in case they die !
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
@@ -229,7 +218,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
override def run() {
while (!driverClosed) {
allocateMissingExecutor()
- sendProgress()
+ logDebug("Sending progress")
+ yarnAllocator.allocateResources()
Thread.sleep(sleepTime)
}
}
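
With sendProgress() inlined, each pass of the reporter thread both tops up missing executors and, via allocateResources(), doubles as the progress heartbeat to the ResourceManager. A standalone sketch of that loop shape; the two actions are passed in as functions because this example does not reproduce YarnAllocationHandler:

```scala
object ReporterThreadSketch {
  @volatile var driverClosed = false

  def launchReporterThread(sleepTime: Long,
                           allocateMissingExecutor: () => Unit,
                           heartbeat: () => Unit): Thread = {
    val t = new Thread(new Runnable {
      override def run(): Unit = {
        while (!driverClosed) {
          allocateMissingExecutor() // re-request containers that have exited
          heartbeat()               // an allocate call also serves as the RM heartbeat
          Thread.sleep(math.max(0L, sleepTime))
        }
      }
    })
    t.setDaemon(true)
    t.start()
    t
  }
}
```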
@@ -241,20 +231,14 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
t
}
- private def sendProgress() {
- logDebug("Sending progress")
- // simulated with an allocate request with no nodes requested ...
- yarnAllocator.allocateResources()
- }
-
def finishApplicationMaster(status: FinalApplicationStatus) {
- logInfo("finish ApplicationMaster with " + status)
- amClient.unregisterApplicationMaster(status, "" /* appMessage */ , "" /* appTrackingUrl */)
+ logInfo("Unregistering ApplicationMaster with " + status)
+ val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
+ amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl)
}
}
-
object ExecutorLauncher {
def main(argStrings: Array[String]) {
val args = new ApplicationMasterArguments(argStrings)