
Commit 3e2492b

change updateStateByKey() to easy API

Committed Oct 10, 2014 · 1 parent 182be73

3 files changed: +72 −17 lines changed
 
examples/src/main/python/streaming/stateful_network_wordcount.py (new file, +57 −0)

@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the
+ network every second.
+
+ Usage: stateful_network_wordcount.py <hostname> <port>
+   <hostname> and <port> describe the TCP server that Spark Streaming
+   would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+    `$ nc -lk 9999`
+ and then run the example
+    `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
+        localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
+        exit(-1)
+    sc = SparkContext(appName="PythonStreamingNetworkWordCount")
+    ssc = StreamingContext(sc, 1)
+    ssc.checkpoint("checkpoint")
+
+    def updateFunc(new_values, last_sum):
+        return sum(new_values) + (last_sum or 0)
+
+    lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+    running_counts = lines.flatMap(lambda line: line.split(" "))\
+        .map(lambda word: (word, 1))\
+        .updateStateByKey(updateFunc)
+
+    running_counts.pprint()
+
+    ssc.start()
+    ssc.awaitTermination()
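Note: the updateFunc in this example is plain Python, so its semantics are easy to sanity-check outside Spark. On a key's first batch there is no previous state, so last_sum arrives as None, which (last_sum or 0) handles. A minimal sketch with made-up batch values:

def updateFunc(new_values, last_sum):
    # last_sum is None the first time a key is seen
    return sum(new_values) + (last_sum or 0)

assert updateFunc([1, 1, 1], None) == 3   # first batch for a key
assert updateFunc([2, 2], 3) == 7         # later batch adds to the running total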

python/pyspark/streaming/dstream.py (+5 −5)

@@ -564,19 +564,19 @@ def updateStateByKey(self, updateFunc, numPartitions=None):
         Return a new "state" DStream where the state for each key is updated by applying
         the given function on the previous state of the key and the new values of the key.
 
-        @param updateFunc: State update function ([(k, vs, s)] -> [(k, s)]).
-            If `s` is None, then `k` will be eliminated.
+        @param updateFunc: State update function. If this function returns None, then
+            corresponding state key-value pair will be eliminated.
         """
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
 
         def reduceFunc(t, a, b):
             if a is None:
-                g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
+                g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
             else:
                 g = a.cogroup(b, numPartitions)
-                g = g.map(lambda (k, (va, vb)): (k, list(vb), list(va)[0] if len(va) else None))
-            state = g.mapPartitions(lambda x: updateFunc(x))
+                g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
+            state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
             return state.filter(lambda (k, v): v is not None)
 
         jreduceFunc = TransformFunction(self._sc, reduceFunc,
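For intuition, the per-batch transition that reduceFunc implements can be modeled with plain dictionaries. This is only a sketch under the new (vs, s) -> s contract; step, prev_state, and batch are hypothetical names, not part of the commit:

# Hypothetical plain-Python model of one updateStateByKey step (not Spark code):
# group the batch by key, pair each key's new values with its previous state,
# apply update_func, and drop any key whose new state is None.
def step(prev_state, batch, update_func):
    grouped = {}
    for k, v in batch:
        grouped.setdefault(k, []).append(v)
    new_state = {}
    for k in set(grouped) | set(prev_state or {}):
        vs = grouped.get(k, [])        # new values for k (may be empty)
        s = (prev_state or {}).get(k)  # previous state, or None
        result = update_func(vs, s)
        if result is not None:         # None eliminates the key, as in the filter above
            new_state[k] = result
    return new_state

count = lambda vs, s: sum(vs) + (s or 0)
state = step(None, [("a", 1), ("a", 1)], count)   # {'a': 2}
state = step(state, [("a", 1), ("b", 1)], count)  # {'a': 3, 'b': 1}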

python/pyspark/streaming/tests.py (+10 −12)

@@ -119,7 +119,7 @@ def _sort_result_based_on_key(self, outputs):
             output.sort(key=lambda x: x[0])
 
 
-class TestBasicOperations(PySparkStreamingTestCase):
+class BasicOperationTests(PySparkStreamingTestCase):
 
     def test_map(self):
         """Basic operation test for DStream.map."""

@@ -340,15 +340,13 @@ def func(a, b):
         expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
         self._test_func(input, func, expected, True, input2)
 
-    def update_state_by_key(self):
+    def test_update_state_by_key(self):
 
-        def updater(it):
-            for k, vs, s in it:
-                if not s:
-                    s = vs
-                else:
-                    s.extend(vs)
-                yield (k, s)
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s
 
         input = [[('k', i)] for i in range(5)]
 

@@ -360,7 +358,7 @@ def func(dstream):
         self._test_func(input, func, expected)
 
 
-class TestWindowFunctions(PySparkStreamingTestCase):
+class WindowFunctionTests(PySparkStreamingTestCase):
 
     timeout = 20
 

@@ -417,7 +415,7 @@ def test_reduce_by_invalid_window(self):
         self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
 
 
-class TestStreamingContext(PySparkStreamingTestCase):
+class StreamingContextTests(PySparkStreamingTestCase):
 
     duration = 0.1
 

@@ -480,7 +478,7 @@ def func(rdds):
         self.assertEqual([2, 3, 1], self._take(dstream, 3))
 
 
-class TestCheckpoint(PySparkStreamingTestCase):
+class CheckpointTests(PySparkStreamingTestCase):
 
     def setUp(self):
         pass
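The rewritten updater above receives a key's new values and previous state directly instead of iterating over (k, vs, s) triples. Run standalone, it simply accumulates every value seen for a key (a small illustration with made-up batches):

def updater(vs, s):
    if not s:
        s = []
    s.extend(vs)
    return s

s = updater([0], None)   # [0]
s = updater([1], s)      # [0, 1]
s = updater([2], s)      # [0, 1, 2]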
