From abdab302dae0b37297af4dd316db207b977a2fc0 Mon Sep 17 00:00:00 2001 From: Bluenix Date: Mon, 6 Jun 2022 15:05:13 +0200 Subject: [PATCH] GH-93521: Filter out `__weakref__` slot if present in bases --- Lib/dataclasses.py | 13 +++++++++---- Lib/test/test_dataclasses.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 4645ebfa71e71c..18ab69053fd520 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1156,11 +1156,16 @@ def _add_slots(cls, is_frozen, weakref_slot): itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1])) ) # The slots for our class. Remove slots from our base classes. Add - # '__weakref__' if weakref_slot was given. + # '__weakref__' if weakref_slot was given, unless it is already present. cls_dict["__slots__"] = tuple( - itertools.chain( - itertools.filterfalse(inherited_slots.__contains__, field_names), - ("__weakref__",) if weakref_slot else ()) + itertools.filterfalse( + inherited_slots.__contains__, + itertools.chain( + # gh-93521: '__weakref__' also needs to be filtered out if + # already present in inherited_slots + field_names, ('__weakref__',) if weakref_slot else () + ) + ), ) for field_name in field_names: diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index cf29cd07516f0d..1f607478324330 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -3109,6 +3109,22 @@ def test_weakref_slot_make_dataclass(self): "weakref_slot is True but slots is False"): B = make_dataclass('B', [('a', int),], weakref_slot=True) + def test_weakref_slot_subclass_also_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + # A *can* also specify weakref_slot=True if it wants to (gh-93521) + @dataclass(slots=True, weakref_slot=True) + class A(Base): + ... + + # __weakref__ is in the base class, not A. But an instance of A + # is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) class TestDescriptors(unittest.TestCase): def test_set_name(self):