rollback RDD.setContext(), use textFileStream() to test checkpointing
davies committed Oct 1, 2014
1 parent bd8a4c2 commit 7a88f9f
Showing 4 changed files with 28 additions and 41 deletions.
core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala

@@ -84,7 +84,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
 
 private[spark] class ParallelCollectionRDD[T: ClassTag](
     @transient sc: SparkContext,
-    data: Seq[T],
+    @transient data: Seq[T],
     numSlices: Int,
     locationPrefs: Map[Int, Seq[String]])
   extends RDD[T](sc, Nil) {
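Why @transient matters here (a reading of the change, not stated in the diff): ParallelCollectionRDD wraps a driver-side Seq, and marking that field @transient keeps the collection out of the serialized form of the RDD. Once such an RDD has been checkpointed, its partitions are read back from the checkpoint files, so the original collection is not needed for recovery. A minimal PySpark sketch of that checkpoint round trip (the checkpoint directory is hypothetical):

    from pyspark import SparkContext

    sc = SparkContext("local[1]", "checkpoint-demo")
    sc.setCheckpointDir("/tmp/rdd-ckpt")    # assumption: any writable directory
    rdd = sc.parallelize(range(10), 2)      # backed by a driver-side collection
    rdd.checkpoint()
    rdd.count()            # first action materializes the RDD and writes the checkpoint
    print(rdd.collect())   # now answered from checkpoint files, not the original collection
    sc.stop()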
8 changes: 0 additions & 8 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -82,14 +82,6 @@ abstract class RDD[T: ClassTag](
   def this(@transient oneParent: RDD[_]) =
     this(oneParent.context , List(new OneToOneDependency(oneParent)))
 
-  // setContext after loading from checkpointing
-  private[spark] def setContext(s: SparkContext) = {
-    if (sc != null && sc != s) {
-      throw new SparkException("Context is already set in " + this + ", cannot set it again")
-    }
-    sc = s
-  }
-
   private[spark] def conf = sc.conf
   // =======================================================================
   // Methods that should be implemented by subclasses of RDD
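The method removed above was an escape hatch added earlier in this patch series to re-attach a live SparkContext to RDDs deserialized from a checkpoint. This commit rolls it back: an RDD stays bound to the context that created it, and checkpoint recovery is instead exercised through an input source that can be rebuilt from checkpoint data (the textFileStream test below). A small PySpark illustration of that binding (sketch only, names are assumptions):

    from pyspark import SparkConf, SparkContext

    sc = SparkContext(conf=SparkConf().setMaster("local[1]").setAppName("bind-demo"))
    rdd = sc.parallelize([1, 2, 3])
    sc.stop()
    # rdd is now unusable: it still references the stopped context, and there is
    # no supported way to point it at a new one (which setContext tried to offer)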
52 changes: 27 additions & 25 deletions python/pyspark/streaming/tests.py
@@ -70,7 +70,8 @@ def _collect(self, dstream):
 
         def get_output(_, rdd):
             r = rdd.collect()
-            result.append(r)
+            if r:
+                result.append(r)
         dstream.foreachRDD(get_output)
         return result
 
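Context for the new if r: guard (inferred): a file-based stream produces an empty RDD for every batch interval in which no new files arrive, so an unfiltered collector would interleave empty lists with the real output. The collector in isolation, with the names used by the harness above:

    result = []

    def get_output(_, rdd):        # foreachRDD passes (batch time, RDD)
        r = rdd.collect()
        if r:                      # skip empty batches from idle intervals
            result.append(r)

    dstream.foreachRDD(get_output)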
@@ -449,24 +450,18 @@ def test_queueStream(self):
         time.sleep(1)
         self.assertEqual(input, result[:3])
 
-    # TODO: fix this test
-    # def test_textFileStream(self):
-    #     input = [range(i) for i in range(3)]
-    #     dstream = self.ssc.queueStream(input)
-    #     d = os.path.join(tempfile.gettempdir(), str(id(self)))
-    #     if not os.path.exists(d):
-    #         os.makedirs(d)
-    #     dstream.saveAsTextFiles(os.path.join(d, 'test'))
-    #     self.ssc.start()
-    #     time.sleep(1)
-    #     self.ssc.stop(False, True)
-    #
-    #     self.ssc = StreamingContext(self.sc, self.batachDuration)
-    #     dstream2 = self.ssc.textFileStream(d)
-    #     result = self._collect(dstream2)
-    #     self.ssc.start()
-    #     time.sleep(2)
-    #     self.assertEqual(input, result[:3])
+    def test_textFileStream(self):
+        d = tempfile.mkdtemp()
+        self.ssc = StreamingContext(self.sc, self.duration)
+        dstream2 = self.ssc.textFileStream(d).map(int)
+        result = self._collect(dstream2)
+        self.ssc.start()
+        time.sleep(1)
+        for name in ('a', 'b'):
+            with open(os.path.join(d, name), "w") as f:
+                f.writelines(["%d\n" % i for i in range(10)])
+        time.sleep(2)
+        self.assertEqual([range(10) * 2], result[:3])
 
     def test_union(self):
         input = [range(i) for i in range(3)]
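How the new test works: textFileStream(d) monitors directory d and only ingests files that appear after the stream has started, hence the sleep between start() and the writes; the two ten-line files land in the same batch, so the first collected batch equals range(10) * 2. A standalone sketch of the same pattern (master URL and timings are assumptions):

    import os, tempfile, time
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "textfilestream-demo")
    ssc = StreamingContext(sc, 1)              # 1-second batch interval
    d = tempfile.mkdtemp()
    ssc.textFileStream(d).map(int).pprint()    # parse each line, print each batch
    ssc.start()
    time.sleep(1)                              # let the stream start watching d
    with open(os.path.join(d, "a"), "w") as f:
        f.writelines(["%d\n" % i for i in range(10)])
    ssc.awaitTermination(5)                    # timeout form, as used in the tests
    ssc.stop()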
@@ -503,27 +498,34 @@ def tearDown(self):
 
     def test_get_or_create(self):
         result = [0]
+        inputd = tempfile.mkdtemp()
 
         def setup():
            conf = SparkConf().set("spark.default.parallelism", 1)
            sc = SparkContext(conf=conf)
            ssc = StreamingContext(sc, .2)
-           rdd = sc.parallelize(range(1), 1)
-           dstream = ssc.queueStream([rdd], default=rdd)
-           result[0] = self._collect(dstream.countByWindow(1, 0.2))
+           dstream = ssc.textFileStream(inputd)
+           result[0] = self._collect(dstream.count())
            return ssc
 
         tmpd = tempfile.mkdtemp("test_streaming_cps")
         ssc = StreamingContext.getOrCreate(tmpd, setup)
         ssc.start()
         time.sleep(1)
+        with open(os.path.join(inputd, "1"), 'w') as f:
+            f.writelines(["%d\n" % i for i in range(10)])
+        ssc.awaitTermination(4)
-        ssc.stop()
+        ssc.stop(True, True)
-        expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5
-        self.assertEqual(expected, result[0][:10])
+        self.assertEqual([[10]], result[0][:1])
 
         ssc = StreamingContext.getOrCreate(tmpd, setup)
         ssc.start()
         time.sleep(1)
+        with open(os.path.join(inputd, "1"), 'w') as f:
+            f.writelines(["%d\n" % i for i in range(10)])
+        ssc.awaitTermination(2)
-        ssc.stop()
+        ssc.stop(True, True)
 
 
 if __name__ == "__main__":
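The reworked test exercises the standard driver-recovery pattern: StreamingContext.getOrCreate(path, setup) calls setup() only when no usable checkpoint exists under path (checkpointing to it on first creation, as the test above relies on); on the second call it rebuilds the context and its DStream graph from the checkpoint instead. That is why the input had to become a textFileStream: unlike a queueStream, it can be reconstructed after a restart. The skeleton of the pattern (directories and timings hypothetical):

    import tempfile
    from pyspark import SparkConf, SparkContext
    from pyspark.streaming import StreamingContext

    inputd = tempfile.mkdtemp()    # watched input directory
    ckptd = tempfile.mkdtemp()     # checkpoint directory

    def setup():
        # runs only when ckptd holds no usable checkpoint
        sc = SparkContext(conf=SparkConf().set("spark.default.parallelism", 1))
        ssc = StreamingContext(sc, 0.2)
        ssc.textFileStream(inputd).count().pprint()
        return ssc

    ssc = StreamingContext.getOrCreate(ckptd, setup)   # restart-safe entry point
    ssc.start()
    ssc.awaitTermination(2)
    ssc.stop(True, True)           # stop the SparkContext too, gracefully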
streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala

@@ -17,7 +17,6 @@
 
 package org.apache.spark.streaming.dstream
 
-import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.UnionRDD
 import scala.collection.mutable.Queue

@@ -33,12 +32,6 @@ class QueueInputDStream[T: ClassTag](
     defaultRDD: RDD[T]
   ) extends InputDStream[T](ssc) {
 
-  private[streaming] override def setContext(s: StreamingContext) {
-    super.setContext(s)
-    queue.map(_.setContext(s.sparkContext))
-    defaultRDD.setContext(s.sparkContext)
-  }
-
   override def start() { }
 
   override def stop() { }
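With the setContext machinery gone, QueueInputDStream no longer tries to re-home its queued RDDs after checkpoint recovery. A queue-based stream is inherently tied to the live SparkContext, which is why the checkpoint tests above moved to textFileStream. For reference, the queueStream shape in PySpark (sketch; assumes a running sc and ssc):

    # Each queued RDD belongs to the current SparkContext, so this stream
    # cannot be reconstructed from a checkpoint after a restart.
    rdd_queue = [sc.parallelize(range(n + 1)) for n in range(3)]
    dstream = ssc.queueStream(rdd_queue, default=sc.parallelize([0]))
    dstream.pprint()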
