Skip to content

Commit

Permalink
[SPARK-1687] [PySpark] pickable namedtuple
Browse files Browse the repository at this point in the history
Add an hook to replace original namedtuple with an pickable one, then namedtuple could be used in RDDs.

PS: pyspark should be import BEFORE "from collections import namedtuple"

Author: Davies Liu <davies.liu@gmail.com>

Closes apache#1623 from davies/namedtuple and squashes the following commits:

045dad8 [Davies Liu] remove unrelated code changes
4132f32 [Davies Liu] address comment
55b1c1a [Davies Liu] fix tests
61f86eb [Davies Liu] replace all the reference of namedtuple to new hacked one
98df6c6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple
f7b1bde [Davies Liu] add hack for CloudPickleSerializer
0c5c849 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple
21991e6 [Davies Liu] hack namedtuple in __main__ module, make it picklable.
93b03b8 [Davies Liu] pickable namedtuple
  • Loading branch information
davies authored and JoshRosen committed Aug 4, 2014
1 parent e053c55 commit 59f84a9
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
60 changes: 60 additions & 0 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
import marshal
import struct
import sys
import types
import collections

from pyspark import cloudpickle


Expand Down Expand Up @@ -267,6 +270,63 @@ def dumps(self, obj):
return obj


# Hook namedtuple, make it picklable

__cls = {}


def _restore(name, fields, value):
""" Restore an object of namedtuple"""
k = (name, fields)
cls = __cls.get(k)
if cls is None:
cls = collections.namedtuple(name, fields)
__cls[k] = cls
return cls(*value)


def _hack_namedtuple(cls):
""" Make class generated by namedtuple picklable """
name = cls.__name__
fields = cls._fields
def __reduce__(self):
return (_restore, (name, fields, tuple(self)))
cls.__reduce__ = __reduce__
return cls


def _hijack_namedtuple():
""" Hack namedtuple() to make it picklable """
global _old_namedtuple # or it will put in closure

def _copy_func(f):
return types.FunctionType(f.func_code, f.func_globals, f.func_name,
f.func_defaults, f.func_closure)

_old_namedtuple = _copy_func(collections.namedtuple)

def namedtuple(name, fields, verbose=False, rename=False):
cls = _old_namedtuple(name, fields, verbose, rename)
return _hack_namedtuple(cls)

# replace namedtuple with new one
collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
collections.namedtuple.func_code = namedtuple.func_code

# hack the cls already generated by namedtuple
# those created in other module can be pickled as normal,
# so only hack those in __main__ module
for n, o in sys.modules["__main__"].__dict__.iteritems():
if (type(o) is type and o.__base__ is tuple
and hasattr(o, "_fields")
and "__reduce__" not in o.__dict__):
_hack_namedtuple(o) # hack inplace


_hijack_namedtuple()


class PickleSerializer(FramedSerializer):
"""
Serializes objects using Python's cPickle serializer:
Expand Down
19 changes: 19 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ def test_huge_dataset(self):
m._cleanup()


class SerializationTestCase(unittest.TestCase):

def test_namedtuple(self):
from collections import namedtuple
from cPickle import dumps, loads
P = namedtuple("P", "x y")
p1 = P(1, 3)
p2 = loads(dumps(p1, 2))
self.assertEquals(p1, p2)


class PySparkTestCase(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -298,6 +309,14 @@ def test_itemgetter(self):
self.assertEqual([1], rdd.map(itemgetter(1)).collect())
self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())

def test_namedtuple_in_rdd(self):
from collections import namedtuple
Person = namedtuple("Person", "id firstName lastName")
jon = Person(1, "Jon", "Doe")
jane = Person(2, "Jane", "Doe")
theDoes = self.sc.parallelize([jon, jane])
self.assertEquals([jon, jane], theDoes.collect())


class TestIO(PySparkTestCase):

Expand Down

0 comments on commit 59f84a9

Please sign in to comment.