diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 829c8c9b14481..1b18789040360 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1765,11 +1765,6 @@ def groupByKey(self, numPartitions=None):
         Group the values for each key in the RDD into a single sequence.
         Hash-partitions the resulting RDD with numPartitions partitions.
 
-        The values in the resulting RDD is iterable object L{ResultIterable},
-        they can be iterated only once. The `len(values)` will result in
-        iterating values, so they can not be iterable after calling
-        `len(values)`.
-
         Note: If you are grouping in order to perform an aggregation (such as a
         sum or average) over each key, using reduceByKey or aggregateByKey will
         provide much better performance.
diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py
index 7093c938ee6fd..1ab5ce14c3531 100644
--- a/python/pyspark/resultiterable.py
+++ b/python/pyspark/resultiterable.py
@@ -15,24 +15,25 @@
 # limitations under the License.
 #
 
+import collections
+
 __all__ = ["ResultIterable"]
 
 
-class ResultIterable(object):
+class ResultIterable(collections.Iterable):
 
     """
     A special result iterable. This is used because the standard
     iterator can not be pickled
     """
 
-    def __init__(self, it):
-        self.it = it
+    def __init__(self, data):
+        self.data = data
+        self.index = 0
+        self.maxindex = len(data)
 
     def __iter__(self):
-        return iter(self.it)
+        return iter(self.data)
 
     def __len__(self):
-        try:
-            return len(self.it)
-        except TypeError:
-            return sum(1 for _ in self.it)
+        return len(self.data)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 138c6a1c6732b..4afa82f4b2973 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -220,7 +220,7 @@ def __repr__(self):
         return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
 
 
-class FlattedValuesSerializer(BatchedSerializer):
+class FlattenedValuesSerializer(BatchedSerializer):
 
     """
     Serializes a stream of list of pairs, split the list of values
@@ -240,7 +240,7 @@ def load_stream(self, stream):
         return self.serializer.load_stream(stream)
 
     def __repr__(self):
-        return "FlattedValuesSerializer(%d)" % self.batchSize
+        return "FlattenedValuesSerializer(%d)" % self.batchSize
 
 
 class AutoBatchedSerializer(BatchedSerializer):
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 62d76f96b91b8..45163d0fceb85 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -25,7 +25,7 @@
 import random
 
 import pyspark.heapq3 as heapq
-from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattedValuesSerializer, \
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
     CompressedSerializer, AutoBatchedSerializer
 
 
@@ -372,8 +372,10 @@ def iteritems(self):
 
     def _external_items(self):
         """ Return all partitioned items as iterator """
+        assert not self.data
         if any(self.pdata):
             self._spill()
+        # disable partitioning and spilling when merge combiners from disk
         self.pdata = []
 
         try:
@@ -546,7 +548,10 @@ class ExternalList(object):
 
     def __init__(self, values):
         self.values = values
-        self.disk_count = 0
+        if values and isinstance(values[0], list):
+            self.count = sum(len(i) for i in values)
+        else:
+            self.count = len(values)
         self._file = None
         self._ser = None
 
@@ -555,16 +560,16 @@ def __getstate__(self):
             self._file.flush()
             f = os.fdopen(os.dup(self._file.fileno()))
             f.seek(0)
-            bytes = f.read()
+            serialized = f.read()
         else:
-            bytes = ''
-        return self.values, self.disk_count, bytes
+            serialized = ''
+        return self.values, self.count, serialized
 
     def __setstate__(self, item):
-        self.values, self.disk_count, bytes = item
-        if bytes:
+        self.values, self.count, serialized = item
+        if serialized:
             self._open_file()
-            self._file.write(bytes)
+            self._file.write(serialized)
         else:
             self._file = None
             self._ser = None
@@ -583,10 +588,11 @@ def __iter__(self):
             yield v
 
     def __len__(self):
-        return self.disk_count + len(self.values)
+        return self.count
 
     def append(self, value):
         self.values.append(value)
+        self.count += len(value) if isinstance(value, list) else 1
         # dump them into disk if the key is huge
         if len(self.values) >= self.LIMIT:
             self._spill()
@@ -610,7 +616,6 @@ def _spill(self):
         used_memory = get_used_memory()
         pos = self._file.tell()
         self._ser.dump_stream([self.values], self._file)
-        self.disk_count += len(self.values)
         self.values = []
         gc.collect()
         DiskBytesSpilled += self._file.tell() - pos
@@ -657,7 +662,10 @@ def __init__(self, iterators):
         self.iterators = iterators
 
     def __len__(self):
-        return sum(len(vs) for vs in self.iterators)
+        try:
+            return len(self.iterators)
+        except:
+            return sum(len(i) for i in self.iterators)
 
     def __iter__(self):
         return itertools.chain.from_iterable(self.iterators)
@@ -702,10 +710,10 @@ class ExternalGroupBy(ExternalMerger):
     """
     SORT_KEY_LIMIT = 1000
 
-    def _flatted_serializer(self):
+    def flattened_serializer(self):
         assert isinstance(self.serializer, BatchedSerializer)
         ser = self.serializer
-        return FlattedValuesSerializer(ser, 20)
+        return FlattenedValuesSerializer(ser, 20)
 
     def _object_size(self, obj):
         return len(obj)
@@ -734,7 +742,7 @@ def _spill(self):
             # sort them before dumping into disks
             self._sorted = len(self.data) < self.SORT_KEY_LIMIT
             if self._sorted:
-                self.serializer = self._flatted_serializer()
+                self.serializer = self.flattened_serializer()
                 for k in sorted(self.data.keys()):
                     h = self._partition(k)
                     self.serializer.dump_stream([(k, self.data[k])], streams[h])
@@ -802,7 +810,7 @@ def load_partition(j):
         else:
             # Flatten the combined values, so it will not consume huge
             # memory during merging sort.
-            ser = self._flatted_serializer()
+            ser = self.flattened_serializer()
             sorter = ExternalSorter(self.memory_limit, ser)
             sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                          key=operator.itemgetter(0))
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7f6a5e3bf3655..03fdebaf21291 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -741,9 +741,11 @@ def test_external_group_by_key(self):
         filtered = gkv.filter(lambda (k, vs): k == 1)
         self.assertEqual(1, filtered.count())
         self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+        self.assertEqual([(N/3, N/3)],
+                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
         result = filtered.collect()[0][1]
         self.assertEqual(N/3, len(result))
-        self.assertTrue(isinstance(result.it, shuffle.ChainedIterable))
+        self.assertTrue(isinstance(result.data, shuffle.ChainedIterable))
 
     def test_sort_on_empty_rdd(self):
         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
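
Not part of the patch: a minimal usage sketch of the behaviour the changes above imply. The master string, app name and sample data are made up; it only assumes a local SparkContext. With ResultIterable backed by a concrete sequence again, len() on the grouped values no longer exhausts them, so they can be measured and then iterated.

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "groupByKey-len-demo")

    # groupByKey() values arrive as ResultIterable objects
    grouped = sc.parallelize([("a", 1), ("a", 2), ("b", 3)]).groupByKey()

    def summarize(kv):
        key, values = kv
        n = len(values)        # measuring the size no longer consumes the values
        total = sum(values)    # iterating afterwards still works
        return (key, n, total)

    print(sorted(grouped.map(summarize).collect()))
    # expected: [('a', 2, 3), ('b', 1, 3)]

    sc.stop()

This is the same pattern the added assertion in test_external_group_by_key exercises: len(x) followed by len(list(x)) must report the same size for each group.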