Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom __getstate__, __setstate__ for slotted classes #513

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,16 +593,19 @@ def _create_slots_class(self):
if qualname is not None:
cd["__qualname__"] = qualname

# __weakref__ is not writable.
state_attr_names = tuple(
an for an in self._attr_names if an != "__weakref__"
)
def get_state_attr_names(self):
# __weakref__ is not writable.
return (
a.name for a in self.__attrs_attrs__ if a.name != "__weakref__"
)

def slots_getstate(self):
"""
Automatically created by attrs.
"""
return tuple(getattr(self, name) for name in state_attr_names)
return tuple(
getattr(self, name) for name in get_state_attr_names(self)
)

hash_caching_enabled = self._cache_hash

Expand All @@ -611,7 +614,7 @@ def slots_setstate(self, state):
Automatically created by attrs.
"""
__bound_setattr = _obj_setattr.__get__(self, Attribute)
for name, value in zip(state_attr_names, state):
for name, value in zip(get_state_attr_names(self), state):
__bound_setattr(name, value)
# Clearing the hash code cache on deserialization is needed
# because hash codes can change from run to run. See issue
Expand All @@ -622,8 +625,20 @@ def slots_setstate(self, state):
__bound_setattr(_hash_cache_field, None)

# slots and frozen require __getstate__/__setstate__ to work
cd["__getstate__"] = slots_getstate
cd["__setstate__"] = slots_setstate
if hasattr(self._cls, "__getstate__"):
if not hasattr(self._cls, "__setstate__"):
raise ValueError(
"__setstate__ must be implemented when __getstate__ is."
)
else:
# TODO look closer! Some classes (like Exception) implements
andhus marked this conversation as resolved.
Show resolved Hide resolved
# __setstate__ only together with custom __reduce__
# if hasattr(self._cls, "__setstate__"):
# raise ValueError(
# "__getstate__ must be implemented when __setstate__ is."
# )
cd["__getstate__"] = slots_getstate
cd["__setstate__"] = slots_setstate

# Create new class based on old class and our methods.
cls = type(self._cls)(self._cls.__name__, self._cls.__bases__, cd)
Expand Down
143 changes: 142 additions & 1 deletion tests/test_slots.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Unit tests for slot-related functionality.
"""

import pickle
import weakref

import pytest
Expand Down Expand Up @@ -529,3 +529,144 @@ class C(object):
w = weakref.ref(c)

assert c is w()


def _loads_dumps(instance):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO test all protocols

"""
Pickle loads-dumps round trip test helper.
"""
return pickle.loads(pickle.dumps(instance))


def test_getstate_setstate_auto_generated():
"""
__getstate__ and __setstate__ are auto generated (if not implemented).
"""
c1 = C1Slots(x=10, y=2)
assert (10, 2) == c1.__getstate__()

c1.__setstate__((1, 0))
assert C1Slots(x=1, y=0) == c1

c1_new = _loads_dumps(C1Slots(x=11, y=22))
assert C1Slots(x=11, y=22) == c1_new


@attr.s(slots=True, hash=True)
class SubC1Slots(C1Slots):
z = attr.ib()


@attr.s(slots=True, hash=True)
class SubC1(C1):
z = attr.ib()


def test_subclass_getstate_setstate_auto_generated():
"""
Autogenerated __getstate__ and __setstate__ works for subclasses.
"""
for subclass in [SubC1, SubC1Slots]:
sc1 = subclass(x=1, y=2, z="z")
assert (1, 2, "z") == sc1.__getstate__()

sc1.__setstate__((10, 0, "zz"))
assert subclass(x=10, y=0, z="zz") == sc1

sc1_new = _loads_dumps(subclass(x=11, y=22, z="zz"))
assert subclass(x=11, y=22, z="zz") == sc1_new


@attr.s(slots=True)
class CustomGetSetState(object):
a = attr.ib()
b = attr.ib()

def __getstate__(self):
"""Modify value for `b` just to test that this method is called."""
return self.a, self.b + "_placeholder"

def __setstate__(self, state):
"""Modify value for `b` just to test that this method is called."""
self.a, self.b = state
self.b += "_reconstructed"


@attr.s(slots=True)
class SubCustomGetSetState(CustomGetSetState):
c = attr.ib()

def __getstate__(self):
return super(SubCustomGetSetState, self).__getstate__() + tuple(self.c)

def __setstate__(self, state):
super(SubCustomGetSetState, self).__setstate__(state[:2])
self.c = state[-1]


@attr.s(slots=True)
class SubCustomGetSetStateNoOverride(CustomGetSetState):
c = attr.ib()


def test_custom_getstate_setstate_effective():
"""
Custom __getstate__ and __setstate__ are used when implemented.
"""
c = CustomGetSetState("a", "b")
assert ("a", "b_placeholder") == c.__getstate__()

c.__setstate__(("other_a", "b"))
assert CustomGetSetState("other_a", "b_reconstructed") == c

c_new = _loads_dumps(CustomGetSetState("a", "b"))
assert CustomGetSetState("a", "b_placeholder_reconstructed") == c_new


def test_subclass_custom_getstate_setstate_referencing_super():
"""
Custom __getstate__ and __setstate__ referencing super class ok.
"""
sc = SubCustomGetSetState("a", "b", "c")
assert ("a", "b_placeholder", "c") == sc.__getstate__()

sc.__setstate__(("a", "b", "other_c"))
assert SubCustomGetSetState("a", "b_reconstructed", "other_c") == sc

sc_new = _loads_dumps(SubCustomGetSetState("a", "b", "c"))
assert (
SubCustomGetSetState("a", "b_placeholder_reconstructed", "c") == sc_new
)


def test_subclass_custom_getstate_setstate_no_override():
"""
No overriding of custom __getstate__ and __setstate__ -> superclass
implementation.
"""
sc = SubCustomGetSetStateNoOverride("a", "b", "c")
assert ("a", "b_placeholder") == sc.__getstate__()

sc.__setstate__(("a", "bb"))
assert SubCustomGetSetStateNoOverride("a", "bb_reconstructed", "c") == sc

sc_new = _loads_dumps(SubCustomGetSetStateNoOverride("a", "b", "c"))
with pytest.raises(AttributeError):
sc_new.c

assert "a" == sc_new.a
assert "b_placeholder_reconstructed" == sc_new.b


def test_raise_if_getstate_and_not_setstate_implemented():
"""
ValueError raised if only one of __getstate__, __setstate__ implemented.
"""
with pytest.raises(ValueError):

@attr.s(slots=True)
class CustomGetStateOnly(object):
a = attr.ib()

def __getstate__(self):
return self.a