diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 965de1b01df2ee..04ad409a1bece8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,8 @@ Changes: - Converts now work with frozen classes. `#76 `_ +- Pickling now works with ``__slots__`` classes. + `#81 `_ ---- diff --git a/docs/examples.rst b/docs/examples.rst index d8de3555ee7d78..4ca2a797f755a1 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -441,6 +441,13 @@ Slot classes are a little different than ordinary, dictionary-backed classes: - Since non-slot classes cannot be turned into slot classes after they have been created, ``attr.s(.., slots=True)`` will *replace* the class it is applied to with a copy. In almost all cases this isn't a problem, but we mention it for the sake of completeness. +- Using :mod:`pickle` with slot classes requires pickle protocol 2 or greater. + Python 2 uses protocol 0 by default so the protocol needs to be specified. + Python 3 uses protocol 3 by default. + You can support protocol 0 and 1 by implementing :meth:`__getstate__ ` and :meth:`__setstate__ ` methods yourself. + Those methods are created for frozen slot classes because they won't pickle otherwise. + `Think twice `_ before using :mod:`pickle` though. + All in all, setting ``slots=True`` is usually a very good idea. diff --git a/src/attr/_make.py b/src/attr/_make.py index 764fe552fa038d..d2f144ed3f3722 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -233,6 +233,9 @@ def wrap(cls): cls = _add_init(cls, frozen) if frozen is True: cls.__setattr__ = _frozen_setattrs + if slots is True: + # slots and frozen require __getstate__/__setstate__ to work + cls = _add_pickle(cls) if slots is True: cls_dict = dict(cls.__dict__) cls_dict["__slots__"] = tuple(ca_list) @@ -433,6 +436,29 @@ def _add_init(cls, frozen): return cls +def _add_pickle(cls): + """ + Add pickle helpers, needed for frozen and slotted classes + """ + def _slots_getstate__(obj): + """ + Play nice with pickle. + """ + return tuple(getattr(obj, a.name) for a in fields(obj.__class__)) + + def _slots_setstate__(obj, state): + """ + Play nice with pickle. + """ + __bound_setattr = _obj_setattr.__get__(obj, Attribute) + for a, value in zip(fields(obj.__class__), state): + __bound_setattr(a.name, value) + + cls.__getstate__ = _slots_getstate__ + cls.__setstate__ = _slots_setstate__ + return cls + + def fields(cls): """ Returns the tuple of ``attrs`` attributes for a class. @@ -630,6 +656,20 @@ def from_counting_attr(cls, name, ca): in Attribute.__slots__ if k != "name")) + # Don't use _add_pickle since fields(Attribute) doesn't work + def __getstate__(self): + """ + Play nice with pickle. + """ + return tuple(getattr(self, name) for name in self.__slots__) + + def __setstate__(self, state): + """ + Play nice with pickle. + """ + __bound_setattr = _obj_setattr.__get__(self, Attribute) + for name, value in zip(self.__slots__, state): + __bound_setattr(name, value) _a = [Attribute(name=name, default=NOTHING, validator=None, repr=True, cmp=True, hash=True, init=True) diff --git a/tests/test_dark_magic.py b/tests/test_dark_magic.py index 5178e0b6f1f2ed..56c74f7aef4bdb 100644 --- a/tests/test_dark_magic.py +++ b/tests/test_dark_magic.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, division, print_function +import pickle import pytest from hypothesis import given @@ -68,6 +69,11 @@ class Frozen(object): x = attr.ib() +@attr.s(frozen=True, slots=False) +class FrozenNoSlots(object): + x = attr.ib() + + class TestDarkMagic(object): """ Integration tests. @@ -177,3 +183,30 @@ def test_frozen_instance(self, frozen_class): assert e.value.args[0] == "can't set attribute" assert 1 == frozen.x + + @pytest.mark.parametrize("cls", + [C1, C1Slots, C2, C2Slots, Super, SuperSlots, + Sub, SubSlots, Frozen, FrozenNoSlots]) + @pytest.mark.parametrize("protocol", + range(2, pickle.HIGHEST_PROTOCOL + 1)) + def test_pickle_attributes(self, cls, protocol): + """ + Pickling/un-pickling of Attribute instances works. + """ + for attribute in attr.fields(cls): + assert attribute == pickle.loads(pickle.dumps(attribute, protocol)) + + @pytest.mark.parametrize("cls", + [C1, C1Slots, C2, C2Slots, Super, SuperSlots, + Sub, SubSlots, Frozen, FrozenNoSlots]) + @pytest.mark.parametrize("protocol", + range(2, pickle.HIGHEST_PROTOCOL + 1)) + def test_pickle_object(self, cls, protocol): + """ + Pickle object serialization works on all kinds of attrs classes. + """ + if len(attr.fields(cls)) == 2: + obj = cls(123, 456) + else: + obj = cls(123) + assert repr(obj) == repr(pickle.loads(pickle.dumps(obj, protocol)))