Skip to content

Commit

Permalink
[3.12] gh-105332: [Enum] Fix unpickling flags in edge-cases (GH-105348)…
Browse files Browse the repository at this point in the history
… (GH-105520)

* revert enum pickling from by-name to by-value

(cherry picked from commit 4ff5690)

Co-authored-by: Nikita Sobolev <mail@sobolevn.me>
Co-authored-by: Ethan Furman <ethan@stoneleaf.us>
  • Loading branch information
3 people authored Jun 9, 2023
1 parent 68eeab7 commit 2f4a2d6
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 23 deletions.
11 changes: 10 additions & 1 deletion Doc/howto/enum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,16 @@ from that module.
nested in other classes.

It is possible to modify how enum members are pickled/unpickled by defining
:meth:`__reduce_ex__` in the enumeration class.
:meth:`__reduce_ex__` in the enumeration class. The default method is by-value,
but enums with complicated values may want to use by-name::

>>> class MyEnum(Enum):
... __reduce_ex__ = enum.pickle_by_enum_name

.. note::

Using by-name for flags is not recommended, as unnamed aliases will
not unpickle.


Functional API
Expand Down
30 changes: 9 additions & 21 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP',
'global_flag_repr', 'global_enum_repr', 'global_str', 'global_enum',
'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE',
'pickle_by_global_name', 'pickle_by_enum_name',
]


Expand Down Expand Up @@ -922,7 +923,6 @@ def _convert_(cls, name, module, filter, source=None, *, boundary=None, as_globa
body['__module__'] = module
tmp_cls = type(name, (object, ), body)
cls = _simple_enum(etype=cls, boundary=boundary or KEEP)(tmp_cls)
cls.__reduce_ex__ = _reduce_ex_by_global_name
if as_global:
global_enum(cls)
else:
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def __hash__(self):
return hash(self._name_)

def __reduce_ex__(self, proto):
return getattr, (self.__class__, self._name_)
return self.__class__, (self._value_, )

# enum.property is used to provide access to the `name` and
# `value` attributes of enum members while keeping some measure of
Expand Down Expand Up @@ -1307,8 +1307,14 @@ def _generate_next_value_(name, start, count, last_values):
return name.lower()


def _reduce_ex_by_global_name(self, proto):
def pickle_by_global_name(self, proto):
# should not be used with Flag-type enums
return self.name
_reduce_ex_by_global_name = pickle_by_global_name

def pickle_by_enum_name(self, proto):
# should not be used with Flag-type enums
return getattr, (self.__class__, self._name_)

class FlagBoundary(StrEnum):
"""
Expand All @@ -1330,23 +1336,6 @@ class Flag(Enum, boundary=STRICT):
Support for flags
"""

def __reduce_ex__(self, proto):
cls = self.__class__
unknown = self._value_ & ~cls._flag_mask_
member_value = self._value_ & cls._flag_mask_
if unknown and member_value:
return _or_, (cls(member_value), unknown)
for val in _iter_bits_lsb(member_value):
rest = member_value & ~val
if rest:
return _or_, (cls(rest), cls._value2member_map_.get(val))
else:
break
if self._name_ is None:
return cls, (self._value_,)
else:
return getattr, (cls, self._name_)

_numeric_repr_ = repr

@staticmethod
Expand Down Expand Up @@ -2073,7 +2062,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
# unless some values aren't comparable, in which case sort by name
members.sort(key=lambda t: t[0])
cls = etype(name, members, module=module, boundary=boundary or KEEP)
cls.__reduce_ex__ = _reduce_ex_by_global_name
return cls

_stdlib_enums = IntEnum, StrEnum, IntFlag
28 changes: 27 additions & 1 deletion Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def load_tests(loader, tests, ignore):
'../../Doc/library/enum.rst',
optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE,
))
if os.path.exists('Doc/howto/enum.rst'):
tests.addTests(doctest.DocFileSuite(
'../../Doc/howto/enum.rst',
optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE,
))
return tests

MODULE = __name__
Expand Down Expand Up @@ -66,6 +71,7 @@ class FlagStooges(Flag):
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389
except Exception as exc:
FlagStooges = exc

Expand All @@ -74,17 +80,20 @@ class FlagStoogesWithZero(Flag):
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389

class IntFlagStooges(IntFlag):
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389

class IntFlagStoogesWithZero(IntFlag):
NOFLAG = 0
LARRY = 1
CURLY = 2
MOE = 4
BIG = 389

# for pickle test and subclass tests
class Name(StrEnum):
Expand Down Expand Up @@ -1942,14 +1951,17 @@ class NEI(NamedInt, Enum):
__qualname__ = 'NEI'
x = ('the-x', 1)
y = ('the-y', 2)

self.assertIs(NEI.__new__, Enum.__new__)
self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)")
globals()['NamedInt'] = NamedInt
globals()['NEI'] = NEI
NI5 = NamedInt('test', 5)
self.assertEqual(NI5, 5)
self.assertEqual(NEI.y.value, 2)
with self.assertRaisesRegex(TypeError, "name and value must be specified"):
test_pickle_dump_load(self.assertIs, NEI.y)
# fix pickle support and try again
NEI.__reduce_ex__ = enum.pickle_by_enum_name
test_pickle_dump_load(self.assertIs, NEI.y)
test_pickle_dump_load(self.assertIs, NEI)

Expand Down Expand Up @@ -3252,11 +3264,17 @@ def test_pickle(self):
test_pickle_dump_load(self.assertEqual,
FlagStooges.CURLY&~FlagStooges.CURLY)
test_pickle_dump_load(self.assertIs, FlagStooges)
test_pickle_dump_load(self.assertEqual, FlagStooges.BIG)
test_pickle_dump_load(self.assertEqual,
FlagStooges.CURLY|FlagStooges.BIG)

test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.CURLY)
test_pickle_dump_load(self.assertEqual,
FlagStoogesWithZero.CURLY|FlagStoogesWithZero.MOE)
test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.NOFLAG)
test_pickle_dump_load(self.assertEqual, FlagStoogesWithZero.BIG)
test_pickle_dump_load(self.assertEqual,
FlagStoogesWithZero.CURLY|FlagStoogesWithZero.BIG)

test_pickle_dump_load(self.assertIs, IntFlagStooges.CURLY)
test_pickle_dump_load(self.assertEqual,
Expand All @@ -3266,11 +3284,19 @@ def test_pickle(self):
test_pickle_dump_load(self.assertEqual, IntFlagStooges(0))
test_pickle_dump_load(self.assertEqual, IntFlagStooges(0x30))
test_pickle_dump_load(self.assertIs, IntFlagStooges)
test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG)
test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG|1)
test_pickle_dump_load(self.assertEqual,
IntFlagStooges.CURLY|IntFlagStooges.BIG)

test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.CURLY)
test_pickle_dump_load(self.assertEqual,
IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.MOE)
test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.NOFLAG)
test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG)
test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG|1)
test_pickle_dump_load(self.assertEqual,
IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.BIG)

def test_contains_tf(self):
Open = self.Open
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Revert pickling method from by-name back to by-value.

0 comments on commit 2f4a2d6

Please sign in to comment.