From fe909c26f29f954de7648f8bb7fbd1916ee34cc6 Mon Sep 17 00:00:00 2001 From: Graham Dumpleton Date: Fri, 23 Jun 2023 12:11:43 +1000 Subject: [PATCH] It was not possible to update __class__ attribute via the ObjectProxy when using C implementation. --- src/wrapt/_wrappers.c | 15 ++++++++++++++- tests/test_object_proxy.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/wrapt/_wrappers.c b/src/wrapt/_wrappers.c index 7ff1085a..cf355311 100644 --- a/src/wrapt/_wrappers.c +++ b/src/wrapt/_wrappers.c @@ -1464,6 +1464,19 @@ static PyObject *WraptObjectProxy_get_class( /* ------------------------------------------------------------------------- */ +static int WraptObjectProxy_set_class(WraptObjectProxyObject *self, + PyObject *value) +{ + if (!self->wrapped) { + PyErr_SetString(PyExc_ValueError, "wrapper has not been initialized"); + return -1; + } + + return PyObject_SetAttrString(self->wrapped, "__class__", value); +} + +/* ------------------------------------------------------------------------- */ + static PyObject *WraptObjectProxy_get_annotations( WraptObjectProxyObject *self) { @@ -1779,7 +1792,7 @@ static PyGetSetDef WraptObjectProxy_getset[] = { { "__doc__", (getter)WraptObjectProxy_get_doc, (setter)WraptObjectProxy_set_doc, 0 }, { "__class__", (getter)WraptObjectProxy_get_class, - NULL, 0 }, + (setter)WraptObjectProxy_set_class, 0 }, { "__annotations__", (getter)WraptObjectProxy_get_annotations, (setter)WraptObjectProxy_set_annotations, 0 }, { "__wrapped__", (getter)WraptObjectProxy_get_wrapped, diff --git a/tests/test_object_proxy.py b/tests/test_object_proxy.py index 50522b27..352c66b4 100644 --- a/tests/test_object_proxy.py +++ b/tests/test_object_proxy.py @@ -2417,5 +2417,23 @@ def function(_self, self, *args, **kwargs): self.assertEqual(result, ('self', (), dict(arg1='arg1'))) +class TestOverridingSpecialAttributes(unittest.TestCase): + + def test_overriding_class_attribute(self): + class Object1: pass + class Object2(Object1): pass + + o1 = Object1() + + self.assertEqual(o1.__class__, type(o1)) + + o2 = Object2() + + self.assertEqual(o2.__class__, type(o2)) + + o2.__class__ = type(o1) + + self.assertEqual(o2.__class__, type(o1)) + if __name__ == '__main__': unittest.main()