WIP fixing 1.1 merge
giwa committed Sep 21, 2014
1 parent 5cdb6fa commit 550dfd9
Showing 1 changed file with 62 additions and 14 deletions.
python/pyspark/streaming/dstream.py: 76 changes (62 additions & 14 deletions)
@@ -25,6 +25,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.streaming.util import rddToFileName, RDDFunction
from pyspark.rdd import portable_hash, _parse_memory
from pyspark.traceback_utils import SCCallSiteSync

from py4j.java_collections import ListConverter, MapConverter
@@ -40,6 +41,7 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
self._jrdd_deserializer = jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self._partitionFunc = None

def context(self):
"""
@@ -161,32 +163,71 @@ def _mergeCombiners(iterator):

return shuffled.mapPartitions(_mergeCombiners)

def partitionBy(self, numPartitions, partitionFunc=None):
def partitionBy(self, numPartitions, partitionFunc=portable_hash):
"""
Return a copy of the DStream partitioned using the specified partitioner.
"""
if numPartitions is None:
numPartitions = self.ctx._defaultReducePartitions()

if partitionFunc is None:
partitionFunc = lambda x: 0 if x is None else hash(x)

# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.

outputSerializer = self.ctx._unbatched_serializer
#
# def add_shuffle_key(split, iterator):
# buckets = defaultdict(list)
#
# for (k, v) in iterator:
# buckets[partitionFunc(k) % numPartitions].append((k, v))
# for (split, items) in buckets.iteritems():
# yield pack_long(split)
# yield outputSerializer.dumps(items)
# keyed = PipelinedDStream(self, add_shuffle_key)

limit = (_parse_memory(self.ctx._conf.get(
"spark.python.worker.memory", "512m")) / 2)

def add_shuffle_key(split, iterator):

buckets = defaultdict(list)
c, batch = 0, min(10 * numPartitions, 1000)

for (k, v) in iterator:
for k, v in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
c += 1

# check used memory and avg size of chunk of objects
if (c % 1000 == 0 and get_used_memory() > limit
or c > batch):
n, size = len(buckets), 0
for split in buckets.keys():
yield pack_long(split)
d = outputSerializer.dumps(buckets[split])
del buckets[split]
yield d
size += len(d)

avg = (size / n) >> 20
# let 1M < avg < 10M
if avg < 1:
batch *= 1.5
elif avg > 10:
batch = max(batch / 1.5, 1)
c = 0

for split, items in buckets.iteritems():
yield pack_long(split)
yield outputSerializer.dumps(items)
keyed = PipelinedDStream(self, add_shuffle_key)

keyed = self._mapPartitionsWithIndex(add_shuffle_key)

keyed._bypass_serializer = True
with SCCallSiteSync(self.context) as css:
with SCCallSiteSync(self.ctx) as css:
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jdstream = self.ctx._jvm.PythonPairwiseDStream(keyed._jdstream.dstream(),
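
The new add_shuffle_key above does what the in-code comment promises: it forms the hash buckets in Python so only O(numPartitions) objects cross into Java, and it flushes those buckets in adaptively sized batches, retuning the batch whenever worker memory or the buffered record count crosses a threshold so the average serialized chunk stays roughly between 1 MB and 10 MB. The hunk assumes get_used_memory and pack_long are imported elsewhere in the file; the standalone sketch below inlines simplified stand-ins for them and for the output serializer, so it illustrates the technique rather than reproducing the PySpark implementation.

from collections import defaultdict
import pickle
import struct


def pack_long(split):
    # 8-byte marker written before each serialized bucket (stand-in for
    # pyspark.serializers.pack_long).
    return struct.pack("!q", split)


def add_shuffle_key(iterator, num_partitions, partition_func,
                    memory_limit_mb=256, used_memory_mb=lambda: 0):
    # Yield alternating split markers and serialized buckets of (k, v) pairs.
    # used_memory_mb stands in for pyspark's get_used_memory(), which samples
    # the worker's resident set size; pickle stands in for the output serializer.
    buckets = defaultdict(list)
    count, batch = 0, min(10 * num_partitions, 1000)

    for k, v in iterator:
        buckets[partition_func(k) % num_partitions].append((k, v))
        count += 1

        # Flush when memory is high or a full batch of records is buffered.
        if (count % 1000 == 0 and used_memory_mb() > memory_limit_mb) or count > batch:
            n, size = len(buckets), 0
            for split in list(buckets.keys()):
                yield pack_long(split)
                data = pickle.dumps(buckets.pop(split))
                yield data
                size += len(data)

            # Retune the batch size so the average serialized chunk stays
            # between roughly 1 MB and 10 MB.
            avg_mb = (size // n) >> 20 if n else 0
            if avg_mb < 1:
                batch *= 1.5
            elif avg_mb > 10:
                batch = max(batch / 1.5, 1)
            count = 0

    # Flush whatever is left once the partition is exhausted.
    for split, items in buckets.items():
        yield pack_long(split)
        yield pickle.dumps(items)


if __name__ == "__main__":
    pairs = iter([("a", 1), ("b", 2), ("a", 3), ("c", 4)])
    out = list(add_shuffle_key(pairs, num_partitions=2, partition_func=hash))
    # out alternates 8-byte split markers with pickled lists of pairs
    print(len(out))
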
@@ -428,6 +469,10 @@ def get_output(rdd, time):


class PipelinedDStream(DStream):
"""
Since PipelinedDStream mirrors PipelinedRDD, any change made to PipelinedRDD
should be applied here in the same way.
"""
def __init__(self, prev, func, preservesPartitioning=False):
if not isinstance(prev, PipelinedDStream) or not prev._is_pipelinable():
# This transformation is the first in its stage:
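
The docstring above states the intent: like PipelinedRDD, this class composes consecutive per-partition functions into a single function, so a chain of transformations costs only one pass through the Python worker per partition. A minimal sketch of that composition, using illustrative names rather than the actual class:

def pipeline(prev_func, func):
    # Compose two mapPartitionsWithIndex-style functions, each with the
    # signature f(split, iterator) -> iterator, the same way PipelinedRDD's
    # pipeline_func does.
    def pipeline_func(split, iterator):
        return func(split, prev_func(split, iterator))
    return pipeline_func


if __name__ == "__main__":
    double = lambda split, it: (x * 2 for x in it)
    keep_even = lambda split, it: (x for x in it if x % 2 == 0)

    composed = pipeline(double, keep_even)
    print(list(composed(0, iter([1, 2, 3, 4]))))  # doubles first, then filters: [2, 4, 6, 8]
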
@@ -453,19 +498,22 @@ def pipeline_func(split, iterator):
self._jdstream_val = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None

@property
def _jdstream(self):
if self._jdstream_val:
return self._jdstream_val
if self._bypass_serializer:
serializer = NoOpSerializer()
else:
serializer = self.ctx.serializer

command = (self.func, self._prev_jrdd_deserializer, serializer)
ser = CompressedSerializer(CloudPickleSerializer())
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20):  # 1M
broadcast = self.ctx.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
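
The rewritten _jdstream pickles the (func, deserializers) command with CloudPickle and, when the payload exceeds roughly 1 MB, ships it through a broadcast variable so only a small handle is embedded in the job description; that is the path the "compressed by broadcast" comment refers to. A rough standalone sketch of the size check, using plain pickle in place of CloudPickleSerializer and a hypothetical FakeBroadcastManager in place of SparkContext.broadcast:

import pickle


class FakeBroadcast(object):
    # Tiny handle; in Spark the actual payload lives in the block manager.
    def __init__(self, broadcast_id):
        self.broadcast_id = broadcast_id


class FakeBroadcastManager(object):
    # Illustrative stand-in for SparkContext.broadcast().
    def __init__(self):
        self._store = {}

    def broadcast(self, payload):
        bid = len(self._store)
        self._store[bid] = payload
        return FakeBroadcast(bid)


def serialize_command(command, broadcast_manager, limit=1 << 20):
    pickled = pickle.dumps(command)
    if len(pickled) > limit:                      # ~1 MB: too big to inline
        handle = broadcast_manager.broadcast(pickled)
        pickled = pickle.dumps(handle)            # ship only the small handle
    return pickled


if __name__ == "__main__":
    mgr = FakeBroadcastManager()
    small = serialize_command(("func", "deserializer"), mgr)
    big = serialize_command(b"x" * (2 << 20), mgr)
    print(len(small), len(big))                   # the big command shrinks to a handle
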
