address comments, rollback changes in ResultIterable

Davies Liu committed Apr 8, 2015
1 parent e3b8eab commit 0dcf320

Showing 5 changed files with 37 additions and 31 deletions.
5 changes: 0 additions & 5 deletions python/pyspark/rdd.py
@@ -1765,11 +1765,6 @@ def groupByKey(self, numPartitions=None):
         Group the values for each key in the RDD into a single sequence.
         Hash-partitions the resulting RDD with numPartitions partitions.
-        The values in the resulting RDD is iterable object L{ResultIterable},
-        they can be iterated only once. The `len(values)` will result in
-        iterating values, so they can not be iterable after calling
-        `len(values)`.
         Note: If you are grouping in order to perform an aggregation (such as a
         sum or average) over each key, using reduceByKey or aggregateByKey will
         provide much better performance.
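The note retained in the docstring is the usual PySpark guidance: groupByKey ships every value across the shuffle and hands them back per key as a ResultIterable, while reduceByKey combines values map-side first. A minimal sketch of the difference (standard PySpark API; the local SparkContext setup is assumed for illustration only):

from pyspark import SparkContext

sc = SparkContext("local[2]", "groupby-vs-reduceby")  # assumed local test setup
pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])

# groupByKey: each key maps to a ResultIterable holding all of its values
grouped = pairs.groupByKey().mapValues(list).collect()
# e.g. [('a', [1, 1]), ('b', [1])] (ordering may vary)

# reduceByKey: values are combined before the shuffle, moving far less data
summed = pairs.reduceByKey(lambda a, b: a + b).collect()
# e.g. [('a', 2), ('b', 1)]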
17 changes: 9 additions & 8 deletions python/pyspark/resultiterable.py
@@ -15,24 +15,25 @@
 # limitations under the License.
 #
 
+import collections
+
 __all__ = ["ResultIterable"]
 
 
-class ResultIterable(object):
+class ResultIterable(collections.Iterable):
 
     """
     A special result iterable. This is used because the standard
     iterator can not be pickled
     """
 
-    def __init__(self, it):
-        self.it = it
+    def __init__(self, data):
+        self.data = data
+        self.index = 0
+        self.maxindex = len(data)
 
     def __iter__(self):
-        return iter(self.it)
+        return iter(self.data)
 
     def __len__(self):
-        try:
-            return len(self.it)
-        except TypeError:
-            return sum(1 for _ in self.it)
+        return len(self.data)
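Rolling ResultIterable back to a thin wrapper over an in-memory sequence means grouped values can be walked more than once, and len() no longer has to consume the iterator. A small illustration of the restored behaviour with the class as shown in this diff, constructed directly for demonstration (inside Spark it is produced by groupByKey):

from pyspark.resultiterable import ResultIterable

values = ResultIterable([1, 2, 3])
assert len(values) == 3           # len() just delegates to the underlying list
assert list(values) == [1, 2, 3]
assert list(values) == [1, 2, 3]  # a second pass still works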
4 changes: 2 additions & 2 deletions python/pyspark/serializers.py
@@ -220,7 +220,7 @@ def __repr__(self):
         return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
 
 
-class FlattedValuesSerializer(BatchedSerializer):
+class FlattenedValuesSerializer(BatchedSerializer):
 
     """
     Serializes a stream of list of pairs, split the list of values
@@ -240,7 +240,7 @@ def load_stream(self, stream):
         return self.serializer.load_stream(stream)
 
     def __repr__(self):
-        return "FlattedValuesSerializer(%d)" % self.batchSize
+        return "FlattenedValuesSerializer(%d)" % self.batchSize
 
 
 class AutoBatchedSerializer(BatchedSerializer):
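The serializer change here is only the Flatted -> Flattened rename; the class itself (cut off above) splits a (key, [v1, ..., vn]) pair into several (key, chunk) records so that no single batch has to hold a key's entire value list. A rough, hypothetical sketch of that flattening idea, not the actual FlattenedValuesSerializer implementation:

def flatten_values(kv_pairs, batch_size=20):
    """Split each (key, values) pair into (key, chunk) pairs of at most batch_size values."""
    for key, values in kv_pairs:
        for start in range(0, len(values), batch_size):
            yield key, values[start:start + batch_size]

# One key with 50 values becomes three smaller records of 20, 20 and 10 values.
chunks = list(flatten_values([("k", list(range(50)))]))
assert [len(vs) for _, vs in chunks] == [20, 20, 10]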
38 changes: 23 additions & 15 deletions python/pyspark/shuffle.py
@@ -25,7 +25,7 @@
 import random
 
 import pyspark.heapq3 as heapq
-from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattedValuesSerializer, \
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
     CompressedSerializer, AutoBatchedSerializer


@@ -372,8 +372,10 @@ def iteritems(self):
 
     def _external_items(self):
         """ Return all partitioned items as iterator """
+        assert not self.data
         if any(self.pdata):
             self._spill()
+        # disable partitioning and spilling when merge combiners from disk
         self.pdata = []
 
         try:
@@ -546,7 +548,10 @@ class ExternalList(object):
 
     def __init__(self, values):
         self.values = values
-        self.disk_count = 0
+        if values and isinstance(values[0], list):
+            self.count = sum(len(i) for i in values)
+        else:
+            self.count = len(values)
         self._file = None
         self._ser = None

@@ -555,16 +560,16 @@ def __getstate__(self):
             self._file.flush()
             f = os.fdopen(os.dup(self._file.fileno()))
             f.seek(0)
-            bytes = f.read()
+            serialized = f.read()
         else:
-            bytes = ''
-        return self.values, self.disk_count, bytes
+            serialized = ''
+        return self.values, self.count, serialized
 
     def __setstate__(self, item):
-        self.values, self.disk_count, bytes = item
-        if bytes:
+        self.values, self.count, serialized = item
+        if serialized:
             self._open_file()
-            self._file.write(bytes)
+            self._file.write(serialized)
         else:
             self._file = None
             self._ser = None
@@ -583,10 +588,11 @@ def __iter__(self):
             yield v
 
     def __len__(self):
-        return self.disk_count + len(self.values)
+        return self.count
 
     def append(self, value):
         self.values.append(value)
+        self.count += len(value) if isinstance(value, list) else 1
         # dump them into disk if the key is huge
         if len(self.values) >= self.LIMIT:
             self._spill()
@@ -610,7 +616,6 @@ def _spill(self):
         used_memory = get_used_memory()
         pos = self._file.tell()
         self._ser.dump_stream([self.values], self._file)
-        self.disk_count += len(self.values)
         self.values = []
         gc.collect()
         DiskBytesSpilled += self._file.tell() - pos
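The ExternalList changes above swap the disk-only counter (disk_count) for a running count of everything appended, so __len__ stays O(1) whether or not the values have been spilled to disk. A stripped-down, hypothetical sketch of that bookkeeping, ignoring the real class's file handling and serialization:

class CountedList(object):
    """Toy stand-in for ExternalList: keep a running count while old values are 'spilled'."""
    LIMIT = 4

    def __init__(self, values):
        self.values = list(values)
        self.count = len(self.values)
        self.spilled = []                        # stands in for the on-disk file

    def append(self, value):
        self.values.append(value)
        self.count += 1
        if len(self.values) >= self.LIMIT:
            self.spilled.extend(self.values)     # pretend this chunk went to disk
            self.values = []

    def __len__(self):
        return self.count                        # no need to re-read spilled data


lst = CountedList([1, 2, 3])
for v in [4, 5, 6]:
    lst.append(v)
assert len(lst) == 6 and len(lst.values) == 2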
@@ -657,7 +662,10 @@ def __init__(self, iterators):
         self.iterators = iterators
 
     def __len__(self):
-        return sum(len(vs) for vs in self.iterators)
+        try:
+            return len(self.iterators)
+        except:
+            return sum(len(i) for i in self.iterators)
 
     def __iter__(self):
         return itertools.chain.from_iterable(self.iterators)
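In ChainedIterable.__len__ above, len() is tried on the wrapped object itself first, falling back to summing the per-iterable lengths when that object (a plain generator, for instance) cannot report a length. A brief, hypothetical sketch of that try/except pattern, not the real pyspark.shuffle class:

import itertools

class ChainedIterableSketch(object):
    def __init__(self, iterators):
        self.iterators = iterators

    def __len__(self):
        try:
            return len(self.iterators)           # container already knows its total length
        except TypeError:                        # e.g. a generator of value lists
            return sum(len(i) for i in self.iterators)

    def __iter__(self):
        return itertools.chain.from_iterable(self.iterators)

# With a generator of lists, only the fallback branch can apply.
chained = ChainedIterableSketch(list(range(n)) for n in (2, 3))
assert len(chained) == 5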
@@ -702,10 +710,10 @@ class ExternalGroupBy(ExternalMerger):
     """
     SORT_KEY_LIMIT = 1000
 
-    def _flatted_serializer(self):
+    def flattened_serializer(self):
         assert isinstance(self.serializer, BatchedSerializer)
         ser = self.serializer
-        return FlattedValuesSerializer(ser, 20)
+        return FlattenedValuesSerializer(ser, 20)
 
     def _object_size(self, obj):
         return len(obj)
@@ -734,7 +742,7 @@ def _spill(self):
             # sort them before dumping into disks
             self._sorted = len(self.data) < self.SORT_KEY_LIMIT
             if self._sorted:
-                self.serializer = self._flatted_serializer()
+                self.serializer = self.flattened_serializer()
                 for k in sorted(self.data.keys()):
                     h = self._partition(k)
                     self.serializer.dump_stream([(k, self.data[k])], streams[h])
@@ -802,7 +810,7 @@ def load_partition(j):
         else:
             # Flatten the combined values, so it will not consume huge
             # memory during merging sort.
-            ser = self._flatted_serializer()
+            ser = self.flattened_serializer()
             sorter = ExternalSorter(self.memory_limit, ser)
             sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                          key=operator.itemgetter(0))
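The comment in this last hunk explains the renamed serializer's role on the merge path: combined values get flattened into many small (key, chunk) records before ExternalSorter sorts and merges them by key, so the sort never has to hold one key's full value list in memory. A loose, hypothetical illustration of merging key-sorted runs of such flattened records, using the standard library's heapq.merge (the key= argument needs Python 3.5+) rather than Spark's ExternalSorter:

import heapq
import operator
from itertools import groupby

# Two spilled runs, each already sorted by key, values split into small chunks.
run1 = [("a", [1, 2]), ("a", [3]), ("c", [7])]
run2 = [("a", [4]), ("b", [5, 6])]

# Merge the sorted runs by key; only one small chunk per run is resident at a time.
merged = heapq.merge(run1, run2, key=operator.itemgetter(0))

# Regroup chunks that share a key, concatenating their values.
grouped = [(k, [v for _, chunk in chunks for v in chunk])
           for k, chunks in groupby(merged, key=operator.itemgetter(0))]
assert grouped == [("a", [1, 2, 3, 4]), ("b", [5, 6]), ("c", [7])]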
4 changes: 3 additions & 1 deletion python/pyspark/tests.py
@@ -741,9 +741,11 @@ def test_external_group_by_key(self):
         filtered = gkv.filter(lambda (k, vs): k == 1)
         self.assertEqual(1, filtered.count())
         self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+        self.assertEqual([(N/3, N/3)],
+                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
         result = filtered.collect()[0][1]
         self.assertEqual(N/3, len(result))
-        self.assertTrue(isinstance(result.it, shuffle.ChainedIterable))
+        self.assertTrue(isinstance(result.data, shuffle.ChainedIterable))
 
     def test_sort_on_empty_rdd(self):
         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
