From e821d9b5256968c804f4d6eac9a426cf4c6cadb7 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Fri, 8 Sep 2023 16:26:43 -0700 Subject: [PATCH] Enable `"factory sets error and returns nullptr"` --- include/pybind11/detail/init.h | 3 +++ tests/test_factory_constructors.cpp | 10 ++++++++++ tests/test_factory_constructors.py | 7 +++++++ 3 files changed, 20 insertions(+) diff --git a/include/pybind11/detail/init.h b/include/pybind11/detail/init.h index a28680e2..c0ad7d05 100644 --- a/include/pybind11/detail/init.h +++ b/include/pybind11/detail/init.h @@ -39,6 +39,9 @@ PYBIND11_NAMESPACE_BEGIN(initimpl) inline void no_nullptr(void *ptr) { if (!ptr) { + if (PyErr_Occurred()) { + throw error_already_set(); + } throw type_error("pybind11::init(): factory function returned nullptr"); } } diff --git a/tests/test_factory_constructors.cpp b/tests/test_factory_constructors.cpp index aefe52a8..9fa0bd52 100644 --- a/tests/test_factory_constructors.cpp +++ b/tests/test_factory_constructors.cpp @@ -419,6 +419,16 @@ TEST_SUBMODULE(factory_constructors, m) { "__init__", [](NoisyAlloc &a, int i, const std::string &) { new (&a) NoisyAlloc(i); }); }); + struct FactoryErrorAlreadySet {}; + py::class_(m, "FactoryErrorAlreadySet") + .def(py::init([](bool set_error) -> FactoryErrorAlreadySet * { + if (!set_error) { + return new FactoryErrorAlreadySet(); + } + py::set_error(PyExc_ValueError, "factory sets error and returns nullptr"); + return nullptr; + })); + // static_assert testing (the following def's should all fail with appropriate compilation // errors): #if 0 diff --git a/tests/test_factory_constructors.py b/tests/test_factory_constructors.py index ca084d64..f59d6cf2 100644 --- a/tests/test_factory_constructors.py +++ b/tests/test_factory_constructors.py @@ -515,3 +515,10 @@ def __init__(self, bad): str(excinfo.value) == "__init__(self, ...) called with invalid or missing `self` argument" ) + + +def test_factory_error_already_set(): + obj = m.FactoryErrorAlreadySet(False) + assert isinstance(obj, m.FactoryErrorAlreadySet) + with pytest.raises(ValueError, match="factory sets error and returns nullptr"): + m.FactoryErrorAlreadySet(True)