Skip to content

Commit

Permalink
Add support for custom call policies
Browse files Browse the repository at this point in the history
  • Loading branch information
oremanj committed Oct 24, 2024
1 parent fd22b8c commit e343e35
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 37 deletions.
107 changes: 107 additions & 0 deletions docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,113 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`,

Analogous to :cpp:struct:`for_getter`, but for setters.

.. cpp:struct:: template <typename Policy> call_policy

Request that custom logic be inserted around each call to the
bound function, by calling ``Policy::precall(args, nargs, cleanup)`` before
Python-to-C++ argument conversion, and ``Policy::postcall(args, nargs, ret)``
after C++-to-Python return value conversion.

If multiple call policy annotations are provided for the same function, then
their precall and postcall hooks will both execute left-to-right according
to the order in which the annotations were specified when binding the
function.

The :cpp:struct:`nb::call_guard\<T\>() <call_guard>` annotation
should be preferred over ``call_policy`` unless the wrapper logic
depends on the function arguments or return value.
If both annotations are combined, then
:cpp:struct:`nb::call_guard\<T\>() <call_guard>` always executes on
the "inside" (closest to the bound function, after argument
conversions and before return value conversion) regardless of its
position in the function annotations list.

Your ``Policy`` class must define two static member functions:

.. cpp:function:: static void precall(PyObject **args, size_t nargs, detail::cleanup_list *cleanup);

A hook that will be invoked before calling the bound function. More
precisely, it is called after any :ref:`argument locks <argument-locks>`
have been obtained, but before the Python arguments are converted to C++
objects for the function call.

This hook may access or modify the function arguments using the
*args* array, which holds borrowed references in one-to-one
correspondence with the C++ arguments of the bound function. If
the bound function is a method, then ``args[0]`` is its *self*
argument. *nargs* is the number of function arguments. It is actually
passed as ``std::integral_constant<size_t, N>()``, so you can
match on that type if you want to do compile-time checks with it.

The *cleanup* list may be used as it is used in type casters,
to cause some Python object references to be released at some point
after the bound function completes. (If the bound function is part
of an overload set, the cleanup list isn't released until all overloads
have been tried.)

``precall()`` may choose to throw a C++ exception. If it does,
it will preempt execution of the bound function, and the
exception will be treated as if the bound function had thrown it.

.. cpp:function:: static void postcall(PyObject **args, size_t nargs, handle ret);

A hook that will be invoked after calling the bound function and
converting its return value to a Python object, but only if the
bound function returned normally.

*args* stores the Python object arguments, with the same semantics
as in ``precall()``, except that arguments that participated in
implicit conversions will have had their ``args[i]`` pointer updated
to reflect the new Python object that the implicit conversion produced.
*nargs* is the number of arguments, passed as a ``std::integral_constant``
in the same way as for ``precall()``.

*ret* is the bound function's return value. If the bound function returned
normally but its C++ return value could not be converted to a Python
object, then ``postcall()`` will execute with *ret* set to null,
and the Python error indicator might or might not be set to explain why.

If the bound function did not return normally -- either because its
Python object arguments couldn't be converted to the appropriate C++
types, or because the C++ function threw an exception -- then
``postcall()`` **will not execute**. If you need some cleanup logic to
run even in such cases, your ``precall()`` can add a capsule object to the
cleanup list; its destructor will run eventually, but with no promises
as to when. A :cpp:struct:`nb::call_guard <call_guard>` might be a
better choice.

``postcall()`` may choose to throw a C++ exception. If it does,
the result of the wrapped function will be destroyed,
and the exception will be raised in its place, as if the bound function
had thrown it just before returning.

Here is an example policy to demonstrate.
``nb::call_policy<returns_references_to<I>>()`` behaves like
:cpp:class:`nb::keep_alive\<0, I\>() <keep_alive>`, except that the
return value is a treated as a list of objects rather than a single one.

.. code-block:: cpp
template <size_t I>
struct returns_references_to {
static void precall(PyObject **, size_t, nb::detail::cleanup_list *) {}
template <size_t N>
static void postcall(PyObject **args,
std::integral_constant<size_t, N>,
nb::handle ret) {
static_assert(I > 0 && I < N,
"I in returns_references_to<I> must be in the "
"range [1, number of C++ function arguments]");
if (!nb::isinstance<nb::sequence>(ret)) {
throw std::runtime_error("return value should be a sequence");
}
for (nb::handle nurse : ret) {
nb::detail::keep_alive(nurse.ptr(), args[I]);
}
}
};
.. _class_binding_annotations:

Class binding annotations
Expand Down
12 changes: 12 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ case, both modules must use the same nanobind ABI version, or they will be
isolated from each other. Releases that don't explicitly mention an ABI version
below inherit that of the preceding release.

Version TBD (unreleased)
------------------------

- Added a function annotation :cpp:class:`nb::call_policy\<Policy\>()
<call_policy>` which supports custom function wrapping logic,
calling ``Policy::precall()`` before the bound function and
``Policy::postcall()`` after. This is a low-level interface intended
for advanced users. The precall and postcall hooks are able to
observe the Python objects forming the function arguments and return
value, and the precall hook can change the arguments. See the linked
documentation for more details, important caveats, and an example policy.

Version 2.2.0 (October 3, 2024)
-------------------------------

Expand Down
7 changes: 7 additions & 0 deletions docs/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,13 @@ Multiple guards should be specified as :cpp:class:`nb::call_guard\<T1, T2,
T3...\> <call_guard>`. Construction occurs left to right, while destruction
occurs in reverse.

If your wrapping needs are more complex than
:cpp:class:`nb::call_guard\<T\>() <call_guard>` can handle, it is also
possible to define a custom "call policy", which can observe or modify the
Python object arguments and observe the return value. See the documentation of
:cpp:class:`nb::call_policy\<Policy\> <call_policy>` for details.


.. _higher_order_adv:

Higher-order functions
Expand Down
64 changes: 51 additions & 13 deletions include/nanobind/nb_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ struct sig {

struct is_getter { };

template <typename Policy> struct call_policy final {};

NAMESPACE_BEGIN(literals)
constexpr arg operator"" _a(const char *name, size_t) { return arg(name); }
NAMESPACE_END(literals)
Expand Down Expand Up @@ -186,8 +188,9 @@ enum class func_flags : uint32_t {
return_ref = (1 << 15),
/// Does this overload specify a custom function signature (for docstrings, typing)
has_signature = (1 << 16),
/// Does this function have one or more nb::keep_alive() annotations?
has_keep_alive = (1 << 17)
/// Does this function potentially modify the elements of the PyObject*[] array
/// representing its arguments? (nb::keep_alive() or call_policy annotations)
can_mutate_args = (1 << 17)
};

enum cast_flags : uint8_t {
Expand Down Expand Up @@ -384,50 +387,85 @@ NB_INLINE void func_extra_apply(F &, call_guard<Ts...>, size_t &) {}

template <typename F, size_t Nurse, size_t Patient>
NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive<Nurse, Patient>, size_t &) {
f.flags |= (uint32_t) func_flags::has_keep_alive;
f.flags |= (uint32_t) func_flags::can_mutate_args;
}

template <typename F, typename Policy>
NB_INLINE void func_extra_apply(F &f, call_policy<Policy>, size_t &) {
f.flags |= (uint32_t) func_flags::can_mutate_args;
}

template <typename... Ts> struct func_extra_info {
using call_guard = void;
static constexpr bool keep_alive = false;
static constexpr bool pre_post_hooks = false;
static constexpr size_t nargs_locked = 0;
};

template <typename T, typename... Ts> struct func_extra_info<T, Ts...>
: func_extra_info<Ts...> { };

template <typename... Cs, typename... Ts>
struct func_extra_info<nanobind::call_guard<Cs...>, Ts...> : func_extra_info<Ts...> {
struct func_extra_info<call_guard<Cs...>, Ts...> : func_extra_info<Ts...> {
static_assert(std::is_same_v<typename func_extra_info<Ts...>::call_guard, void>,
"call_guard<> can only be specified once!");
using call_guard = nanobind::call_guard<Cs...>;
};

template <size_t Nurse, size_t Patient, typename... Ts>
struct func_extra_info<nanobind::keep_alive<Nurse, Patient>, Ts...> : func_extra_info<Ts...> {
static constexpr bool keep_alive = true;
static constexpr bool pre_post_hooks = true;
};

template <typename Policy, typename... Ts>
struct func_extra_info<call_policy<Policy>, Ts...> : func_extra_info<Ts...> {
static constexpr bool pre_post_hooks = true;
};

template <typename... Ts>
struct func_extra_info<nanobind::arg_locked, Ts...> : func_extra_info<Ts...> {
struct func_extra_info<arg_locked, Ts...> : func_extra_info<Ts...> {
static constexpr size_t nargs_locked = 1 + func_extra_info<Ts...>::nargs_locked;
};

template <typename... Ts>
struct func_extra_info<nanobind::lock_self, Ts...> : func_extra_info<Ts...> {
struct func_extra_info<lock_self, Ts...> : func_extra_info<Ts...> {
static constexpr size_t nargs_locked = 1 + func_extra_info<Ts...>::nargs_locked;
};

template <typename T>
NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { }
NB_INLINE void process_precall(PyObject **, size_t, detail::cleanup_list *, void *) { }

template <size_t NArgs, typename Policy>
NB_INLINE void
process_precall(PyObject **args, std::integral_constant<size_t, NArgs> nargs,
detail::cleanup_list *cleanup, call_policy<Policy> *) {
Policy::precall(args, nargs, cleanup);
}

NB_INLINE void process_postcall(PyObject **, size_t, PyObject *, void *) { }

template <size_t Nurse, size_t Patient>
template <size_t NArgs, size_t Nurse, size_t Patient>
NB_INLINE void
process_keep_alive(PyObject **args, PyObject *result,
nanobind::keep_alive<Nurse, Patient> *) {
process_postcall(PyObject **args, std::integral_constant<size_t, NArgs>,
PyObject *result, nanobind::keep_alive<Nurse, Patient> *) {
static_assert(Nurse != Patient,
"keep_alive with the same argument as both nurse and patient "
"doesn't make sense");
static_assert(Nurse <= NArgs && Patient <= NArgs,
"keep_alive template parameters must be in the range "
"[0, number of C++ function arguments]");
keep_alive(Nurse == 0 ? result : args[Nurse - 1],
Patient == 0 ? result : args[Patient - 1]);
}

template <size_t NArgs, typename Policy>
NB_INLINE void
process_postcall(PyObject **args, std::integral_constant<size_t, NArgs> nargs,
PyObject *result, call_policy<Policy> *) {
// result_guard avoids leaking a reference to the return object
// if postcall throws an exception
object result_guard = steal(result);
Policy::postcall(args, nargs, handle(result));
result_guard.release();
}

NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)
20 changes: 12 additions & 8 deletions include/nanobind/nb_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename Caster>
bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags,
cleanup_list *cleanup, size_t index) {
bool from_python_remember_conv(Caster &c, PyObject **args, uint8_t *args_flags,
cleanup_list *cleanup, size_t index) {
size_t size_before = cleanup->size();
if (!c.from_python(args[index], args_flags[index], cleanup))
return false;

// If an implicit conversion took place, update the 'args' array so that
// the keep_alive annotation can later process this change
// any keep_alive annotation or postcall hook can be aware of this change
size_t size_after = cleanup->size();
if (size_after != size_before)
args[index] = (*cleanup)[size_after - 1];
Expand Down Expand Up @@ -244,9 +244,11 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
}
#endif

if constexpr (Info::keep_alive) {
if ((!from_python_keep_alive(in.template get<Is>(), args,
args_flags, cleanup, Is) || ...))
if constexpr (Info::pre_post_hooks) {
std::integral_constant<size_t, nargs> nargs_c;
(process_precall(args, nargs_c, cleanup, (Extra *) nullptr), ...);
if ((!from_python_remember_conv(in.template get<Is>(), args,
args_flags, cleanup, Is) || ...))
return NB_NEXT_OVERLOAD;
} else {
if ((!in.template get<Is>().from_python(args[Is], args_flags[Is],
Expand Down Expand Up @@ -276,8 +278,10 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),
#endif
}

if constexpr (Info::keep_alive)
(process_keep_alive(args, result, (Extra *) nullptr), ...);
if constexpr (Info::pre_post_hooks) {
std::integral_constant<size_t, nargs> nargs_c;
(process_postcall(args, nargs_c, result, (Extra *) nullptr), ...);
}

return result;
};
Expand Down
32 changes: 16 additions & 16 deletions src/nb_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,21 @@ PyObject *nb_func_new(const void *in_) noexcept {
func_data_prelim<0> *f = (func_data_prelim<0> *) in_;
arg_data *args_in = std::launder((arg_data *) f->args);

bool has_scope = f->flags & (uint32_t) func_flags::has_scope,
has_name = f->flags & (uint32_t) func_flags::has_name,
has_args = f->flags & (uint32_t) func_flags::has_args,
has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs,
has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args,
has_keep_alive = f->flags & (uint32_t) func_flags::has_keep_alive,
has_doc = f->flags & (uint32_t) func_flags::has_doc,
has_signature = f->flags & (uint32_t) func_flags::has_signature,
is_implicit = f->flags & (uint32_t) func_flags::is_implicit,
is_method = f->flags & (uint32_t) func_flags::is_method,
return_ref = f->flags & (uint32_t) func_flags::return_ref,
is_constructor = false,
is_init = false,
is_new = false,
is_setstate = false;
bool has_scope = f->flags & (uint32_t) func_flags::has_scope,
has_name = f->flags & (uint32_t) func_flags::has_name,
has_args = f->flags & (uint32_t) func_flags::has_args,
has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs,
has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args,
can_mutate_args = f->flags & (uint32_t) func_flags::can_mutate_args,
has_doc = f->flags & (uint32_t) func_flags::has_doc,
has_signature = f->flags & (uint32_t) func_flags::has_signature,
is_implicit = f->flags & (uint32_t) func_flags::is_implicit,
is_method = f->flags & (uint32_t) func_flags::is_method,
return_ref = f->flags & (uint32_t) func_flags::return_ref,
is_constructor = false,
is_init = false,
is_new = false,
is_setstate = false;

PyObject *name = nullptr;
PyObject *func_prev = nullptr;
Expand Down Expand Up @@ -292,7 +292,7 @@ PyObject *nb_func_new(const void *in_) noexcept {
maybe_make_immortal((PyObject *) func);

// Check if the complex dispatch loop is needed
bool complex_call = has_keep_alive || has_var_kwargs || has_var_args ||
bool complex_call = can_mutate_args || has_var_kwargs || has_var_args ||
f->nargs >= NB_MAXARGS_SIMPLE;

if (has_args) {
Expand Down
Loading

0 comments on commit e343e35

Please sign in to comment.