Commit eec401e

Committed Sep 26, 2014
refactor, combine TransformedRDD, fix reuse PythonRDD, fix union
1 parent 9a57685; commit eec401e
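
In rough terms (a reading of the diff below, not the author's own summary): chained transform() calls are now collapsed into a single TransformedRDD by composing their Python functions, so each batch crosses the py4j bridge only once; a computed PythonRDD is cached and reused as a template only when it is a single-stage PythonRDD over its input; and union() now keeps the input serializer instead of falling back to the default. A minimal sketch of the composition idea, where f and g are hypothetical transform functions:

# Sketch: how TransformedRDD.__init__ (dstream.py below) fuses two chained
# transforms into one function; f and g are hypothetical examples.
f = lambda rdd, t: rdd.map(str)        # first transform's function
g = lambda rdd, t: rdd.filter(bool)    # second transform's function
combined = lambda rdd, t: g(f(rdd, t), t)   # what the constructor builds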

File tree: 5 files changed (+178 -76 lines)

 

python/pyspark/streaming/context.py (+1 -6)
@@ -15,12 +15,7 @@
 # limitations under the License.
 #
 
-import sys
-from signal import signal, SIGTERM, SIGINT
-import atexit
-import time
-
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
+from pyspark.serializers import UTF8Deserializer
 from pyspark.context import SparkContext
 from pyspark.streaming.dstream import DStream
 from pyspark.streaming.duration import Duration, Seconds

python/pyspark/streaming/dstream.py (+76 -36)
@@ -15,21 +15,15 @@
 # limitations under the License.
 #
 
-from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
 
 from pyspark import RDD
-from pyspark.serializers import NoOpSerializer,\
-    BatchedSerializer, CloudPickleSerializer, pack_long,\
-    CompressedSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.resultiterable import ResultIterable
-from pyspark.streaming.util import rddToFileName, RDDFunction
-from pyspark.rdd import portable_hash, _parse_memory
-from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2
+from pyspark.rdd import portable_hash
+from pyspark.streaming.duration import Seconds
 
-from py4j.java_collections import ListConverter, MapConverter
 
 __all__ = ["DStream"]

@@ -42,7 +36,6 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._jrdd_deserializer = jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
-        self._partitionFunc = None
 
     def context(self):
         """

@@ -159,7 +152,7 @@ def foreachRDD(self, func):
         This is an output operator, so this DStream will be registered as an output
         stream and therefore materialized.
         """
-        jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, t), self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
         self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc)
 
     def pyprint(self):

@@ -306,19 +299,19 @@ def get_output(rdd, time):
         return result
 
     def transform(self, func):
-        return TransformedRDD(self, lambda a, b, t: func(a), cache=True)
-
-    def transformWith(self, func, other):
-        return TransformedRDD(self, lambda a, b, t: func(a, b), other)
+        return TransformedRDD(self, lambda a, t: func(a), True)
 
     def transformWithTime(self, func):
-        return TransformedRDD(self, lambda a, b, t: func(a, t))
+        return TransformedRDD(self, func, False)
+
+    def transformWith(self, func, other, keepSerializer=False):
+        return Transformed2RDD(self, lambda a, b, t: func(a, b), other, keepSerializer)
 
     def repartitions(self, numPartitions):
         return self.transform(lambda rdd: rdd.repartition(numPartitions))
 
     def union(self, other):
-        return self.transformWith(lambda a, b: a.union(b), other)
+        return self.transformWith(lambda a, b: a.union(b), other, True)
 
     def cogroup(self, other):
         return self.transformWith(lambda a, b: a.cogroup(b), other)

@@ -329,32 +322,79 @@ def leftOuterJoin(self, other):
     def rightOuterJoin(self, other):
         return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)
 
-    def slice(self, fromTime, toTime):
-        jrdds = self._jdstream.slice(fromTime._jtime, toTime._jtime)
-        # FIXME: serializer
-        return [RDD(jrdd, self.ctx, self.ctx.serializer) for jrdd in jrdds]
+    def _jtime(self, milliseconds):
+        return self.ctx._jvm.Time(milliseconds)
+
+    def slice(self, begin, end):
+        jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
+        return [RDD(jrdd, self.ctx, self._jrdd_deserializer) for jrdd in jrdds]
+
+    def window(self, windowDuration, slideDuration=None):
+        d = Seconds(windowDuration)
+        if slideDuration is None:
+            return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
+        s = Seconds(slideDuration)
+        return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
+
+    def reduceByWindow(self, reduceFunc, inReduceFunc, windowDuration, slideDuration):
+        pass
+
+    def countByWindow(self, window, slide):
+        pass
+
+    def countByValueAndWindow(self, window, slide, numPartitions=None):
+        pass
+
+    def groupByKeyAndWindow(self, window, slide, numPartitions=None):
+        pass
+
+    def reduceByKeyAndWindow(self, reduceFunc, inReduceFunc, window, slide, numPartitions=None):
+        pass
 
     def updateStateByKey(self, updateFunc):
         # FIXME: convert updateFunc to java JFunction2
         jFunc = updateFunc
         return self._jdstream.updateStateByKey(jFunc)
 
 
-# Window Operations
-# TODO: implement window
-# TODO: implement groupByKeyAndWindow
-# TODO: implement reduceByKeyAndWindow
-# TODO: implement countByValueAndWindow
-# TODO: implement countByWindow
-# TODO: implement reduceByWindow
+class TransformedRDD(DStream):
+    def __init__(self, prev, func, reuse=False):
+        ssc = prev._ssc
+        self._ssc = ssc
+        self.ctx = ssc._sc
+        self._jrdd_deserializer = self.ctx.serializer
+        self.is_cached = False
+        self.is_checkpointed = False
+
+        if isinstance(prev, TransformedRDD) and not prev.is_cached and not prev.is_checkpointed:
+            prev_func = prev.func
+            old_func = func
+            func = lambda rdd, t: old_func(prev_func(rdd, t), t)
+            reuse = reuse and prev.reuse
+            prev = prev.prev
 
+        self.prev = prev
+        self.func = func
+        self.reuse = reuse
+        self._jdstream_val = None
 
-class TransformedRDD(DStream):
-    # TODO: better name for cache
-    def __init__(self, prev, func, other=None, cache=False):
-        # TODO: combine transformed RDD
+    @property
+    def _jdstream(self):
+        if self._jdstream_val is not None:
+            return self._jdstream_val
+
+        jfunc = RDDFunction(self.ctx, self.func, self.prev._jrdd_deserializer)
+        jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
+                                                          jfunc, self.reuse).asJavaDStream()
+        self._jdstream_val = jdstream
+        return jdstream
+
+
+class Transformed2RDD(DStream):
+    def __init__(self, prev, func, other, keepSerializer=False):
         ssc = prev._ssc
-        t = RDDFunction(ssc._sc, func, prev._jrdd_deserializer)
-        jdstream = ssc._jvm.PythonTransformedDStream(prev._jdstream.dstream(),
-                                                     other and other._jdstream, t, cache)
-        DStream.__init__(self, jdstream.asJavaDStream(), ssc, ssc._sc.serializer)
+        jfunc = RDDFunction2(ssc._sc, func, prev._jrdd_deserializer)
+        jdstream = ssc._jvm.PythonTransformed2DStream(prev._jdstream.dstream(),
+                                                      other._jdstream.dstream(), jfunc)
+        jrdd_serializer = prev._jrdd_deserializer if keepSerializer else ssc._sc.serializer
+        DStream.__init__(self, jdstream.asJavaDStream(), ssc, jrdd_serializer)
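
The union fix above comes down to the keepSerializer flag threaded through transformWith: Transformed2RDD otherwise resets the result to ssc._sc.serializer, which presumably mismatched the batches when both inputs still carried their original deserializer. A short usage sketch, assuming a running StreamingContext ssc and the test-only _makeStream helper that appears in tests.py below:

# Hypothetical usage; `ssc` and `_makeStream` are assumed from the tests.
d1 = ssc._makeStream([range(3)])
d2 = ssc._makeStream([range(3, 6)])
u = d1.union(d2)   # calls transformWith(..., other, keepSerializer=True)
# `u` keeps d1's deserializer instead of the default serializer, so the
# unioned batches deserialize correctly.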

python/pyspark/streaming/tests.py (+26)
@@ -213,6 +213,32 @@ def add(a, b):
                     [("a", "11"), ("b", "1"), ("", "111")]]
         self._test_func(input, func, expected, sort=True)
 
+    def test_union(self):
+        input1 = [range(3), range(5), range(1)]
+        input2 = [range(3, 6), range(5, 6), range(1, 6)]
+
+        d1 = self.ssc._makeStream(input1)
+        d2 = self.ssc._makeStream(input2)
+        d = d1.union(d2)
+        result = d.collect()
+        expected = [range(6), range(6), range(6)]
+
+        self.ssc.start()
+        start_time = time.time()
+        # Loop until we get the expected number of results from the stream.
+        while True:
+            current_time = time.time()
+            # Check for timeout.
+            if (current_time - start_time) > self.timeout * 2:
+                break
+            # StreamingContext.awaitTermination is not used here because
+            # calling the py4j server every 50 milliseconds causes an error.
+            time.sleep(0.05)
+            # Check whether the output has reached the expected length.
+            if len(expected) == len(result):
+                break
+        self.assertEqual(expected, result)
+
     def _sort_result_based_on_key(self, outputs):
         """Sort the list based on the first value."""
         for output in outputs:
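
The busy-wait loop in test_union is a pattern the other streaming tests could share; a sketch of it factored into a helper (hypothetical, not part of this commit):

import time

def wait_for_output(result, expected_len, timeout, poll=0.05):
    # Poll until `result` (a list the running stream appends to) reaches
    # the expected length, or until the timeout expires.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if len(result) >= expected_len:
            return True
        time.sleep(poll)  # avoid calling into the py4j server too often
    return False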

python/pyspark/streaming/util.py (+29 -2)
@@ -28,9 +28,36 @@ def __init__(self, ctx, func, jrdd_deserializer):
         self.func = func
         self.deserializer = jrdd_deserializer
 
-    def call(self, jrdd, jrdd2, milliseconds):
+    def call(self, jrdd, milliseconds):
         try:
             rdd = RDD(jrdd, self.ctx, self.deserializer)
+            r = self.func(rdd, milliseconds)
+            if r:
+                return r._jrdd
+        except:
+            import traceback
+            traceback.print_exc()
+
+    def __repr__(self):
+        return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
+
+    class Java:
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+
+
+class RDDFunction2(object):
+    """
+    This class is used as a py4j callback; it corresponds to the Scala
+    interface org.apache.spark.streaming.api.python.PythonRDDFunction2.
+    """
+    def __init__(self, ctx, func, jrdd_deserializer):
+        self.ctx = ctx
+        self.func = func
+        self.deserializer = jrdd_deserializer
+
+    def call(self, jrdd, jrdd2, milliseconds):
+        try:
+            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None
             other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None
             r = self.func(rdd, other, milliseconds)
             if r:

@@ -43,7 +70,7 @@ def __repr__(self):
         return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
 
     class Java:
-        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
+        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
 
 
 def rddToFileName(prefix, suffix, time):
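
Both wrapper classes use py4j's callback mechanism: a Python object that declares a nested class Java with an implements list can be handed to the JVM, which then invokes its call() method from Scala (the gateway's callback server must be running for this direction to work). A generic sketch of the pattern, with a hypothetical interface name:

# Generic py4j callback sketch; 'com.example.Callback' is a hypothetical
# JVM interface with a single call(String) method.
class EchoCallback(object):
    def call(self, message):
        # Invoked from the JVM through the py4j callback server.
        return message.upper()

    class Java:
        implements = ['com.example.Callback']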

streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala (+46 -32)
@@ -28,77 +28,91 @@ import org.apache.spark.streaming.api.java._
 
 
 /**
- * Interface for Python callback function
+ * Interface for a Python callback function with two arguments
  */
 trait PythonRDDFunction {
-  def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
+  def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
 }
 
+/**
+ * Interface for a Python callback function with three arguments
+ */
+trait PythonRDDFunction2 {
+  def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
+}
 
 /**
  * Transformed DStream in Python.
  *
  * If the result RDD is a PythonRDD, it will be cached as a template for future use;
 * this reduces the number of Python callbacks.
- *
- * @param parent
- * @param parent2
- * @param func
- * @param cache
  */
-class PythonTransformedDStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction,
-                                cache: Boolean = false)
+class PythonTransformedDStream (parent: DStream[_], func: PythonRDDFunction,
+                                var reuse: Boolean = false)
   extends DStream[Array[Byte]] (parent.ssc) {
 
   var lastResult: PythonRDD = _
 
-  override def dependencies = {
-    if (parent2 == null) {
-      List(parent)
-    } else {
-      List(parent, parent2)
-    }
-  }
+  override def dependencies = List(parent)
 
   override def slideDuration: Duration = parent.slideDuration
 
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     val rdd1 = parent.getOrCompute(validTime).getOrElse(null)
-    val rdd2 = if (parent2 != null) parent2.getOrCompute(validTime).getOrElse(null) else null
-
-    val r = if (rdd2 != null) {
-      func.call(JavaRDD.fromRDD(rdd1), JavaRDD.fromRDD(rdd2), validTime.milliseconds)
-    } else if (cache && lastResult != null) {
-      lastResult.copyTo(rdd1).asJavaRDD
+    if (reuse && lastResult != null) {
+      Some(lastResult.copyTo(rdd1))
     } else {
-      func.call(JavaRDD.fromRDD(rdd1), null, validTime.milliseconds)
-    }
-    if (r != null) {
-      if (lastResult == null && r.isInstanceOf[PythonRDD]) {
-        lastResult = r.asInstanceOf[PythonRDD]
+      val r = func.call(JavaRDD.fromRDD(rdd1), validTime.milliseconds).rdd
+      if (reuse && lastResult == null) {
+        r match {
+          case rdd: PythonRDD =>
+            if (rdd.parent(0) == rdd1) {
+              // only one PythonRDD
+              lastResult = rdd
+            } else {
+              // may have multiple stages
+              reuse = false
+            }
+        }
       }
       Some(r)
-    } else {
-      None
     }
   }
 
   val asJavaDStream = JavaDStream.fromDStream(this)
 }
 
+/**
+ * Transformed from two DStreams in Python.
+ */
+class PythonTransformed2DStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction2)
+  extends DStream[Array[Byte]] (parent.ssc) {
+
+  override def dependencies = List(parent, parent2)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    def resultRdd(stream: DStream[_]): JavaRDD[_] = stream.getOrCompute(validTime) match {
+      case Some(rdd) => JavaRDD.fromRDD(rdd)
+      case None => null
+    }
+    Some(func.call(resultRdd(parent), resultRdd(parent2), validTime.milliseconds))
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
 
 /**
  * This is used for foreachRDD() in Python
- * @param prev
- * @param foreachFunction
  */
 class PythonForeachDStream(
     prev: DStream[Array[Byte]],
     foreachFunction: PythonRDDFunction
   ) extends ForEachDStream[Array[Byte]](
     prev,
     (rdd: RDD[Array[Byte]], time: Time) => {
-      foreachFunction.call(rdd.toJavaRDD(), null, time.milliseconds)
+      foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
     }
   ) {
 
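The subtle part of the Scala change is the reuse guard in compute: the first result is cached as a template only when it is a PythonRDD whose direct parent is the input RDD, meaning a single Python stage; any other shape flips reuse off so later batches go back through the callback. A runnable toy model of that decision, with stand-in names rather than the actual Spark classes:

# Toy model of the reuse guard in PythonTransformedDStream.compute above.
# All names are stand-ins; `func` plays the role of the Python callback.
class ReusableTransform(object):
    def __init__(self, func, reuse=True):
        self.func = func
        self.reuse = reuse
        self.last_result = None

    def compute(self, rdd1):
        if self.reuse and self.last_result is not None:
            # Skip the callback: rebind the cached template to this batch.
            return self.last_result.copy_to(rdd1)
        r = self.func(rdd1)
        if self.reuse and self.last_result is None:
            if getattr(r, "parent", None) is rdd1:
                self.last_result = r   # single stage: safe to reuse
            else:
                self.reuse = False     # multiple stages: disable reuse
        return r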