diff --git a/cinderx/StaticPython/classloader.c b/cinderx/StaticPython/classloader.c index 161a94ba89c..224c1450d27 100644 --- a/cinderx/StaticPython/classloader.c +++ b/cinderx/StaticPython/classloader.c @@ -2807,8 +2807,13 @@ update_thunk(_Py_StaticThunk *thunk, PyObject *previous, PyObject *new_value) { Py_CLEAR(thunk->thunk_tcs.tcs_value); if (new_value != NULL) { - thunk->thunk_tcs.tcs_value = new_value; - Py_INCREF(new_value); + PyObject *unwrapped_new = classloader_maybe_unwrap_callable(new_value); + if (unwrapped_new != NULL) { + thunk->thunk_tcs.tcs_value = unwrapped_new; + } else { + thunk->thunk_tcs.tcs_value = new_value; + Py_INCREF(new_value); + } } PyObject *funcref; if (new_value == previous) { @@ -4258,6 +4263,13 @@ _PyClassLoader_ResolveFunction(PyObject *path, PyObject **container) original = NULL; } + if (original != NULL) { + PyObject *res = (PyObject *)get_or_make_thunk(func, original, *container, containerkey); + Py_DECREF(func); + assert(res != NULL); + return res; + } + if (func != NULL) { if (Py_TYPE(func) == &PyStaticMethod_Type) { PyObject *res = Ci_PyStaticMethod_GetFunc(func); @@ -4273,13 +4285,6 @@ _PyClassLoader_ResolveFunction(PyObject *path, PyObject **container) } } - if (original != NULL) { - PyObject *res = (PyObject *)get_or_make_thunk(func, original, *container, containerkey); - Py_DECREF(func); - assert(res != NULL); - return res; - } - return func; } diff --git a/cinderx/test_cinderx/test_compiler/test_static/patch.py b/cinderx/test_cinderx/test_compiler/test_static/patch.py index 74ce7c28657..76ddfe54caa 100644 --- a/cinderx/test_cinderx/test_compiler/test_static/patch.py +++ b/cinderx/test_cinderx/test_compiler/test_static/patch.py @@ -257,6 +257,28 @@ def g(): with patch(f"{mod.__name__}.C.f", autospec=True, return_value=100) as p: self.assertEqual(g(), 100) + def test_patch_staticmethod_with_staticmethod(self): + codestr = """ + class C: + @staticmethod + def f(): + return 42 + + def g(): + return C.f() + """ + with self.in_module(codestr) as mod: + g = mod.g + for i in range(100): + self.assertEqual(g(), 42) + + @staticmethod + def new(): + return 100 + + mod.C.f = new + self.assertEqual(g(), 100) + def test_patch_static_function_non_autospec(self): codestr = """ class C: