diff --git a/docs/api_core.rst b/docs/api_core.rst index 8497ae70..b35bcf9e 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1895,6 +1895,113 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, Analogous to :cpp:struct:`for_getter`, but for setters. +.. cpp:struct:: template <typename Policy> call_policy + + Request that custom logic be inserted around each call to the + bound function, by calling ``Policy::precall(args, nargs, cleanup)`` before + Python-to-C++ argument conversion, and ``Policy::postcall(args, nargs, ret)`` + after C++-to-Python return value conversion. + + If multiple call policy annotations are provided for the same function, then + their precall and postcall hooks will both execute left-to-right according + to the order in which the annotations were specified when binding the + function. + + The :cpp:struct:`nb::call_guard\<T\>() <call_guard>` annotation + should be preferred over ``call_policy`` unless the wrapper logic + depends on the function arguments or return value. + If both annotations are combined, then + :cpp:struct:`nb::call_guard\<T\>() <call_guard>` always executes on + the "inside" (closest to the bound function, after argument + conversions and before return value conversion) regardless of its + position in the function annotations list. + + Your ``Policy`` class must define two static member functions: + + .. cpp:function:: static void precall(PyObject **args, size_t nargs, detail::cleanup_list *cleanup); + + A hook that will be invoked before calling the bound function. More + precisely, it is called after any :ref:`argument locks ` + have been obtained, but before the Python arguments are converted to C++ + objects for the function call. + + This hook may access or modify the function arguments using the + *args* array, which holds borrowed references in one-to-one + correspondence with the C++ arguments of the bound function. If + the bound function is a method, then ``args[0]`` is its *self* + argument. *nargs* is the number of function arguments. 
It is actually + passed as ``std::integral_constant<size_t, N>()``, so you can + match on that type if you want to do compile-time checks with it. + + The *cleanup* list may be used as it is used in type casters, + to cause some Python object references to be released at some point + after the bound function completes. (If the bound function is part + of an overload set, the cleanup list isn't released until all overloads + have been tried.) + + ``precall()`` may choose to throw a C++ exception. If it does, + it will preempt execution of the bound function, and the + exception will be treated as if the bound function had thrown it. + + .. cpp:function:: static void postcall(PyObject **args, size_t nargs, handle ret); + + A hook that will be invoked after calling the bound function and + converting its return value to a Python object, but only if the + bound function returned normally. + + *args* stores the Python object arguments, with the same semantics + as in ``precall()``, except that arguments that participated in + implicit conversions will have had their ``args[i]`` pointer updated + to reflect the new Python object that the implicit conversion produced. + *nargs* is the number of arguments, passed as a ``std::integral_constant`` + in the same way as for ``precall()``. + + *ret* is the bound function's return value. If the bound function returned + normally but its C++ return value could not be converted to a Python + object, then ``postcall()`` will execute with *ret* set to null, + and the Python error indicator might or might not be set to explain why. + + If the bound function did not return normally -- either because its + Python object arguments couldn't be converted to the appropriate C++ + types, or because the C++ function threw an exception -- then + ``postcall()`` **will not execute**. 
If you need some cleanup logic to + run even in such cases, your ``precall()`` can add a capsule object to the + cleanup list; its destructor will run eventually, but with no promises + as to when. A :cpp:struct:`nb::call_guard <call_guard>` might be a + better choice. + + ``postcall()`` may choose to throw a C++ exception. If it does, + the result of the wrapped function will be destroyed, + and the exception will be raised in its place, as if the bound function + had thrown it just before returning. + + Here is an example policy to demonstrate. + ``nb::call_policy<returns_references_to<I>>()`` behaves like + :cpp:class:`nb::keep_alive\<0, I\>() <keep_alive>`, except that the + return value is treated as a list of objects rather than a single one. + + .. code-block:: cpp + + template <size_t I> + struct returns_references_to { + static void precall(PyObject **, size_t, nb::detail::cleanup_list *) {} + + template <size_t N> + static void postcall(PyObject **args, + std::integral_constant<size_t, N>, + nb::handle ret) { + static_assert(I > 0 && I <= N, + "I in returns_references_to<I> must be in the " + "range [1, number of C++ function arguments]"); + if (!nb::isinstance<nb::sequence>(ret)) { + throw std::runtime_error("return value should be a sequence"); + } + for (nb::handle nurse : ret) { + nb::detail::keep_alive(nurse.ptr(), args[I - 1]); + } + } + }; + .. _class_binding_annotations: Class binding annotations diff --git a/docs/changelog.rst b/docs/changelog.rst index 5702a5d6..972a9455 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -15,6 +15,18 @@ case, both modules must use the same nanobind ABI version, or they will be isolated from each other. Releases that don't explicitly mention an ABI version below inherit that of the preceding release. +Version TBD (unreleased) +------------------------ + +- Added a function annotation :cpp:class:`nb::call_policy\<Policy\>() + <call_policy>` which supports custom function wrapping logic, + calling ``Policy::precall()`` before the bound function and + ``Policy::postcall()`` after. 
This is a low-level interface intended + for advanced users. The precall and postcall hooks are able to + observe the Python objects forming the function arguments and return + value, and the precall hook can change the arguments. See the linked + documentation for more details, important caveats, and an example policy. + Version 2.2.0 (October 3, 2024) ------------------------------- diff --git a/docs/functions.rst b/docs/functions.rst index 8c364378..ce721aac 100644 --- a/docs/functions.rst +++ b/docs/functions.rst @@ -459,6 +459,13 @@ Multiple guards should be specified as :cpp:class:`nb::call_guard\ `. Construction occurs left to right, while destruction occurs in reverse. +If your wrapping needs are more complex than +:cpp:class:`nb::call_guard\() ` can handle, it is also +possible to define a custom "call policy", which can observe or modify the +Python object arguments and observe the return value. See the documentation of +:cpp:class:`nb::call_policy\ ` for details. + + .. _higher_order_adv: Higher-order functions diff --git a/include/nanobind/nb_attr.h b/include/nanobind/nb_attr.h index 0476824e..99f628b0 100644 --- a/include/nanobind/nb_attr.h +++ b/include/nanobind/nb_attr.h @@ -151,6 +151,8 @@ struct sig { struct is_getter { }; +template struct call_policy final {}; + NAMESPACE_BEGIN(literals) constexpr arg operator"" _a(const char *name, size_t) { return arg(name); } NAMESPACE_END(literals) @@ -186,8 +188,9 @@ enum class func_flags : uint32_t { return_ref = (1 << 15), /// Does this overload specify a custom function signature (for docstrings, typing) has_signature = (1 << 16), - /// Does this function have one or more nb::keep_alive() annotations? - has_keep_alive = (1 << 17) + /// Does this function potentially modify the elements of the PyObject*[] array + /// representing its arguments? 
(nb::keep_alive() or call_policy annotations) + can_mutate_args = (1 << 17) }; enum cast_flags : uint8_t { @@ -384,12 +387,17 @@ NB_INLINE void func_extra_apply(F &, call_guard, size_t &) {} template NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive, size_t &) { - f.flags |= (uint32_t) func_flags::has_keep_alive; + f.flags |= (uint32_t) func_flags::can_mutate_args; +} + +template +NB_INLINE void func_extra_apply(F &f, call_policy, size_t &) { + f.flags |= (uint32_t) func_flags::can_mutate_args; } template struct func_extra_info { using call_guard = void; - static constexpr bool keep_alive = false; + static constexpr bool pre_post_hooks = false; static constexpr size_t nargs_locked = 0; }; @@ -397,7 +405,7 @@ template struct func_extra_info : func_extra_info { }; template -struct func_extra_info, Ts...> : func_extra_info { +struct func_extra_info, Ts...> : func_extra_info { static_assert(std::is_same_v::call_guard, void>, "call_guard<> can only be specified once!"); using call_guard = nanobind::call_guard; @@ -405,29 +413,59 @@ struct func_extra_info, Ts...> : func_extra_info struct func_extra_info, Ts...> : func_extra_info { - static constexpr bool keep_alive = true; + static constexpr bool pre_post_hooks = true; +}; + +template +struct func_extra_info, Ts...> : func_extra_info { + static constexpr bool pre_post_hooks = true; }; template -struct func_extra_info : func_extra_info { +struct func_extra_info : func_extra_info { static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; }; template -struct func_extra_info : func_extra_info { +struct func_extra_info : func_extra_info { static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; }; -template -NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { } +NB_INLINE void process_precall(PyObject **, size_t, detail::cleanup_list *, void *) { } + +template +NB_INLINE void +process_precall(PyObject **args, std::integral_constant nargs, + detail::cleanup_list 
*cleanup, call_policy *) { + Policy::precall(args, nargs, cleanup); +} + +NB_INLINE void process_postcall(PyObject **, size_t, PyObject *, void *) { } -template +template NB_INLINE void -process_keep_alive(PyObject **args, PyObject *result, - nanobind::keep_alive *) { +process_postcall(PyObject **args, std::integral_constant, + PyObject *result, nanobind::keep_alive *) { + static_assert(Nurse != Patient, + "keep_alive with the same argument as both nurse and patient " + "doesn't make sense"); + static_assert(Nurse <= NArgs && Patient <= NArgs, + "keep_alive template parameters must be in the range " + "[0, number of C++ function arguments]"); keep_alive(Nurse == 0 ? result : args[Nurse - 1], Patient == 0 ? result : args[Patient - 1]); } +template +NB_INLINE void +process_postcall(PyObject **args, std::integral_constant nargs, + PyObject *result, call_policy *) { + // result_guard avoids leaking a reference to the return object + // if postcall throws an exception + object result_guard = steal(result); + Policy::postcall(args, nargs, handle(result)); + result_guard.release(); +} + NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/nb_func.h b/include/nanobind/nb_func.h index da420349..10eb3994 100644 --- a/include/nanobind/nb_func.h +++ b/include/nanobind/nb_func.h @@ -11,14 +11,14 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) template -bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags, - cleanup_list *cleanup, size_t index) { +bool from_python_remember_conv(Caster &c, PyObject **args, uint8_t *args_flags, + cleanup_list *cleanup, size_t index) { size_t size_before = cleanup->size(); if (!c.from_python(args[index], args_flags[index], cleanup)) return false; // If an implicit conversion took place, update the 'args' array so that - // the keep_alive annotation can later process this change + // any keep_alive annotation or postcall hook can be aware of this change size_t size_after = cleanup->size(); 
if (size_after != size_before) args[index] = (*cleanup)[size_after - 1]; @@ -244,9 +244,11 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), } #endif - if constexpr (Info::keep_alive) { - if ((!from_python_keep_alive(in.template get(), args, - args_flags, cleanup, Is) || ...)) + if constexpr (Info::pre_post_hooks) { + std::integral_constant nargs_c; + (process_precall(args, nargs_c, cleanup, (Extra *) nullptr), ...); + if ((!from_python_remember_conv(in.template get(), args, + args_flags, cleanup, Is) || ...)) return NB_NEXT_OVERLOAD; } else { if ((!in.template get().from_python(args[Is], args_flags[Is], @@ -276,8 +278,10 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), #endif } - if constexpr (Info::keep_alive) - (process_keep_alive(args, result, (Extra *) nullptr), ...); + if constexpr (Info::pre_post_hooks) { + std::integral_constant nargs_c; + (process_postcall(args, nargs_c, result, (Extra *) nullptr), ...); + } return result; }; diff --git a/src/nb_func.cpp b/src/nb_func.cpp index 08c4eefa..21883a62 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -196,21 +196,21 @@ PyObject *nb_func_new(const void *in_) noexcept { func_data_prelim<0> *f = (func_data_prelim<0> *) in_; arg_data *args_in = std::launder((arg_data *) f->args); - bool has_scope = f->flags & (uint32_t) func_flags::has_scope, - has_name = f->flags & (uint32_t) func_flags::has_name, - has_args = f->flags & (uint32_t) func_flags::has_args, - has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs, - has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args, - has_keep_alive = f->flags & (uint32_t) func_flags::has_keep_alive, - has_doc = f->flags & (uint32_t) func_flags::has_doc, - has_signature = f->flags & (uint32_t) func_flags::has_signature, - is_implicit = f->flags & (uint32_t) func_flags::is_implicit, - is_method = f->flags & (uint32_t) func_flags::is_method, - return_ref = f->flags & (uint32_t) func_flags::return_ref, - is_constructor = 
false, - is_init = false, - is_new = false, - is_setstate = false; + bool has_scope = f->flags & (uint32_t) func_flags::has_scope, + has_name = f->flags & (uint32_t) func_flags::has_name, + has_args = f->flags & (uint32_t) func_flags::has_args, + has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs, + has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args, + can_mutate_args = f->flags & (uint32_t) func_flags::can_mutate_args, + has_doc = f->flags & (uint32_t) func_flags::has_doc, + has_signature = f->flags & (uint32_t) func_flags::has_signature, + is_implicit = f->flags & (uint32_t) func_flags::is_implicit, + is_method = f->flags & (uint32_t) func_flags::is_method, + return_ref = f->flags & (uint32_t) func_flags::return_ref, + is_constructor = false, + is_init = false, + is_new = false, + is_setstate = false; PyObject *name = nullptr; PyObject *func_prev = nullptr; @@ -292,7 +292,7 @@ PyObject *nb_func_new(const void *in_) noexcept { maybe_make_immortal((PyObject *) func); // Check if the complex dispatch loop is needed - bool complex_call = has_keep_alive || has_var_kwargs || has_var_args || + bool complex_call = can_mutate_args || has_var_kwargs || has_var_args || f->nargs >= NB_MAXARGS_SIMPLE; if (has_args) { diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 22450d87..34e584e3 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -1,6 +1,9 @@ +#include + #include #include #include +#include namespace nb = nanobind; using namespace nb::literals; @@ -12,6 +15,63 @@ struct my_call_guard { ~my_call_guard() { call_guard_value = 2; } }; +struct example_policy { + static inline std::vector> calls; + static void precall(PyObject **args, size_t nargs, + nb::detail::cleanup_list *cleanup) { + PyObject* tup = PyTuple_New(nargs); + for (size_t i = 0; i < nargs; ++i) { + if (!PyUnicode_CheckExact(args[i])) { + Py_DECREF(tup); + throw std::runtime_error("expected only strings"); + } + if (0 == 
PyUnicode_CompareWithASCIIString(args[i], "swapfrom")) { + nb::object replacement = nb::cast("swapto"); + args[i] = replacement.ptr(); + cleanup->append(replacement.release().ptr()); + } + Py_INCREF(args[i]); + PyTuple_SetItem(tup, i, args[i]); + } + calls.emplace_back(nb::steal(tup), nb::cast("")); + } + static void postcall(PyObject **args, size_t nargs, nb::handle ret) { + if (!ret.is_valid()) { + calls.back().second = nb::cast(""); + } else { + calls.back().second = nb::borrow(ret); + } + for (size_t i = 0; i < nargs; ++i) { + if (0 == PyUnicode_CompareWithASCIIString(args[i], "postthrow")) { + throw std::runtime_error("postcall exception"); + } + } + } +}; + +struct numeric_string { + unsigned long number; +}; + +template <> struct nb::detail::type_caster { + NB_TYPE_CASTER(numeric_string, const_name("str")) + + bool from_python(handle h, uint8_t flags, cleanup_list* cleanup) noexcept { + make_caster str_caster; + if (!str_caster.from_python(h, flags, cleanup)) + return false; + const char* str = str_caster.operator cast_t(); + if (!str) + return false; + char* endp; + value.number = strtoul(str, &endp, 10); + return *str && !*endp; + } + static handle from_cpp(numeric_string, rv_policy, handle) noexcept { + return nullptr; + } +}; + int test_31(int i) noexcept { return i; } NB_MODULE(test_functions_ext, m) { @@ -377,4 +437,23 @@ NB_MODULE(test_functions_ext, m) { m.def("test_bytearray_c_str", [](nb::bytearray o) -> const char * { return o.c_str(); }); m.def("test_bytearray_size", [](nb::bytearray o) { return o.size(); }); m.def("test_bytearray_resize", [](nb::bytearray c, int size) { return c.resize(size); }); + + // Test call_policy feature + m.def("test_call_policy", + [](const char* s, numeric_string n) -> const char* { + if (0 == strcmp(s, "returnfail")) { + return "not utf8 \xff"; + } + if (n.number > strlen(s)) { + throw std::runtime_error("offset too large"); + } + return s + n.number; + }, + nb::call_policy()); + + m.def("call_policy_record", + []() { 
+ auto ret = std::move(example_policy::calls); + return ret; + }); } diff --git a/tests/test_functions.py b/tests/test_functions.py index 48824c6e..377ab86e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -650,3 +650,78 @@ def test49_resize(): assert len(o) == 4 t.test_bytearray_resize(o, 8) assert len(o) == 8 + + +def test50_call_policy(): + def case(arg1, arg2, expect_ret): # type: (str, str, str | None) -> str + if hasattr(sys, "getrefcount"): + refs_before = (sys.getrefcount(arg1), sys.getrefcount(arg2)) + + ret = None + try: + ret = t.test_call_policy(arg1, arg2) + assert ret == expect_ret + return ret + finally: + if expect_ret is None: + assert t.call_policy_record() == [] + else: + (((arg1r, arg2r), recorded_ret),) = t.call_policy_record() + assert recorded_ret == expect_ret + assert ret is None or ret is recorded_ret + assert recorded_ret is not expect_ret + + if hasattr(sys, "getrefcount"): + # Make sure no reference leak occurred: should be + # one in getrefcount args, one or two in locals, + # zero or one in the pending-return-value slot. + # We have to decompose this to avoid getting confused + # by transient additional references added by pytest's + # assertion rewriting. 
+ ret_refs = sys.getrefcount(recorded_ret) + assert ret_refs == 2 + 2 * (ret is not None) + + for (passed, recorded) in ((arg1, arg1r), (arg2, arg2r)): + if passed == "swapfrom": + assert recorded == "swapto" + if hasattr(sys, "getrefcount"): + recorded_refs = sys.getrefcount(recorded) + # recorded, arg1r, unnamed tuple, getrefcount arg + assert recorded_refs == 4 + else: + assert passed is recorded + + del passed, recorded, arg1r, arg2r + if hasattr(sys, "getrefcount"): + refs_after = (sys.getrefcount(arg1), sys.getrefcount(arg2)) + assert refs_before == refs_after + + # precall throws exception + with pytest.raises(RuntimeError, match="expected only strings"): + case(12345, "0", None) + + # conversion of args fails + with pytest.raises(TypeError): + case("string", "xxx", "") + + # function throws exception + with pytest.raises(RuntimeError, match="offset too large"): + case("abc", "4", "") + + # conversion of return value fails + with pytest.raises(UnicodeDecodeError): + case("returnfail", "4", "") + + # postcall throws exception + with pytest.raises(RuntimeError, match="postcall exception"): + case("postthrow", "4", "throw") + + # normal call + case("example", "1", "xample") + + # precall modifies args + case("swapfrom", "0", "swapto") + with pytest.raises(TypeError): + case("swapfrom", "xxx", "") + with pytest.raises(RuntimeError, match="offset too large"): + case("swapfrom", "10", "") diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 5bfab347..05d5fe50 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -5,6 +5,8 @@ from typing import Annotated, Any, overload def call_guard_value() -> int: ... +def call_policy_record() -> list[tuple[tuple, object]]: ... + def hash_it(arg: object, /) -> int: ... def identity_i16(arg: int, /) -> int: ... @@ -178,6 +180,8 @@ def test_call_guard() -> int: ... def test_call_guard_wrapper_rvalue_ref(arg: int, /) -> int: ... 
+def test_call_policy(arg0: str, arg1: str, /) -> str: ... + def test_cast_char(arg: object, /) -> str: ... def test_cast_str(arg: object, /) -> str: ...
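
Reviewer note: the annotation dispatch that this patch adds to `func_create` (a fold over `process_precall` with `(Extra *) nullptr` tags, plus a no-op `void *` fallback overload) can be tried outside nanobind. The following is a minimal C++17 sketch under stated assumptions, not nanobind's implementation: `call_policy` and `process_precall` here are simplified stand-ins, `const char *` replaces `PyObject *`, the cleanup list is omitted, and `other_annotation`, `log_policy`, `first_arg_policy`, `dispatcher`, `trace` and `run_demo` are hypothetical names introduced only for illustration.

```cpp
#include <cstddef>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical stand-ins for the annotation types (not nanobind's definitions).
template <typename Policy> struct call_policy {};
struct other_annotation {};

// Records which hooks ran, in order.
inline std::vector<std::string> &trace() {
    static std::vector<std::string> t;
    return t;
}

// Fallback overload: any annotation that is not a call_policy is a no-op.
inline void process_precall(const char **, std::size_t, void *) {}

// Chosen when the tag pointer has type call_policy<Policy> *.
template <std::size_t N, typename Policy>
void process_precall(const char **args,
                     std::integral_constant<std::size_t, N> nargs,
                     call_policy<Policy> *) {
    Policy::precall(args, nargs);
}

// A policy that takes the count as a plain size_t (implicit conversion).
struct log_policy {
    static void precall(const char **, std::size_t nargs) {
        trace().push_back("log:" + std::to_string(nargs));
    }
};

// A policy that matches on the integral_constant for a compile-time check.
struct first_arg_policy {
    template <std::size_t N>
    static void precall(const char **args,
                        std::integral_constant<std::size_t, N>) {
        static_assert(N >= 1, "needs at least one argument");
        trace().push_back(std::string("first:") + args[0]);
    }
};

// Runs the precall hook of every annotation in Extra, left to right.
template <typename... Extra> struct dispatcher {
    template <std::size_t N> static void run(const char *(&args)[N]) {
        std::integral_constant<std::size_t, N> nargs_c{};
        const char **p = args;
        (process_precall(p, nargs_c, (Extra *) nullptr), ...);
    }
};

// Drives the dispatcher with two policies and one unrelated annotation.
std::vector<std::string> run_demo() {
    trace().clear();
    const char *args[] = {"alpha", "beta"};
    dispatcher<call_policy<log_policy>, other_annotation,
               call_policy<first_arg_policy>>::run(args);
    return trace();
}
```

Calling `run_demo()` from a `main` of your own returns the entries `log:2` and `first:alpha`, in that order: the comma fold evaluates left to right, and `other_annotation` falls through to the `void *` overload. Passing the argument count as a `std::integral_constant` lets a policy either accept it as a plain `size_t` or deduce it as a template parameter for `static_assert` checks.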