Skip to content

Commit

Permalink
Set the module of #[pyfunction]s.
Browse files Browse the repository at this point in the history
Previously neither the module nor the name of the module of
pyfunctions were registered. This commit passes the module and
its name when creating a new pyfunction.

PyModule::add_function and PyModule::add_module have been added and are
set to replace `add_wrapped` in a future release. `add_wrapped` is kept
for compatibility reasons during the transition.

Depending on whether a `PyModule` or `Python` is the argument for the
Python function-wrapper, the module will be registered with the function.
  • Loading branch information
sebpuetz committed Sep 3, 2020
1 parent 21ad52a commit 4e3a35f
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 38 deletions.
28 changes: 14 additions & 14 deletions examples/rustapi_module/src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,29 +215,29 @@ impl TzClass {

#[pymodule]
fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(make_date))?;
m.add_wrapped(wrap_pyfunction!(get_date_tuple))?;
m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?;
m.add_wrapped(wrap_pyfunction!(make_time))?;
m.add_wrapped(wrap_pyfunction!(get_time_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_delta))?;
m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_datetime))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?;
m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_date))?;
m.add_function(wrap_pyfunction!(get_date_tuple))?;
m.add_function(wrap_pyfunction!(date_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_time))?;
m.add_function(wrap_pyfunction!(get_time_tuple))?;
m.add_function(wrap_pyfunction!(make_delta))?;
m.add_function(wrap_pyfunction!(get_delta_tuple))?;
m.add_function(wrap_pyfunction!(make_datetime))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple))?;
m.add_function(wrap_pyfunction!(datetime_from_timestamp))?;

// Python 3.6+ functions
#[cfg(Py_3_6)]
{
m.add_wrapped(wrap_pyfunction!(time_with_fold))?;
m.add_function(wrap_pyfunction!(time_with_fold))?;
#[cfg(not(PyPy))]
{
m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?;
}
}

m.add_wrapped(wrap_pyfunction!(issue_219))?;
m.add_function(wrap_pyfunction!(issue_219))?;
m.add_class::<TzClass>()?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion examples/rustapi_module/src/othermod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 {

#[pymodule]
fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_function(wrap_pyfunction!(double))?;

m.add_class::<ModClass>()?;

Expand Down
6 changes: 3 additions & 3 deletions examples/word-count/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize {

#[pymodule]
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(search))?;
m.add_wrapped(wrap_pyfunction!(search_sequential))?;
m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?;
m.add_function(wrap_pyfunction!(search))?;
m.add_function(wrap_pyfunction!(search_sequential))?;
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;

Ok(())
}
4 changes: 2 additions & 2 deletions guide/src/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn double(x: usize) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap();
m.add_function(wrap_pyfunction!(num_kwds)).unwrap();
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ fn subfunction() -> String {

#[pymodule]
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

#[pymodule]
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_module(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down
31 changes: 27 additions & 4 deletions pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
let item: syn::ItemFn = syn::parse_quote! {
fn block_wrapper() {
#function_to_python
#module_name.add_wrapped(&#function_wrapper_ident)?;
#module_name.add_function(&#function_wrapper_ident)?;
}
};
stmts.extend(item.block.stmts.into_iter());
Expand Down Expand Up @@ -193,7 +193,17 @@ pub fn add_fn_to_module(
let wrapper = function_c_wrapper(&func.sig.ident, &spec);

Ok(quote! {
fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject {
fn #function_wrapper_ident<'a>(
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
) -> pyo3::PyObject {
let arg = args.into();
let (py, maybe_module) = match arg {
pyo3::derive_utils::WrapPyFunctionArguments::Python(py) => (py, None),
pyo3::derive_utils::WrapPyFunctionArguments::PyModule(module) => {
let py = <pyo3::types::PyModule as pyo3::PyNativeType>::py(module);
(py, Some(module))
}
};
#wrapper

let _def = pyo3::class::PyMethodDef {
Expand All @@ -203,12 +213,25 @@ pub fn add_fn_to_module(
ml_doc: #doc,
};

let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = unsafe { pyo3::ffi::PyModule_GetNameObject(mod_ptr) };
if name.is_null() {
let err = PyErr::fetch(py);
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
}
(mod_ptr, name)
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};

let function = unsafe {
pyo3::PyObject::from_owned_ptr(
py,
pyo3::ffi::PyCFunction_New(
pyo3::ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(_def.as_method_def())),
::std::ptr::null_mut()
mod_ptr,
name
)
)
};
Expand Down
19 changes: 19 additions & 0 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,22 @@ where
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
}
}

/// Enum to abstract over the arguments of Python function wrappers.
#[doc(hidden)]
pub enum WrapPyFunctionArguments<'a> {
Python(Python<'a>),
PyModule(&'a PyModule),
}

impl<'a> From<Python<'a>> for WrapPyFunctionArguments<'a> {
fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> {
WrapPyFunctionArguments::Python(py)
}
}

impl<'a> From<&'a PyModule> for WrapPyFunctionArguments<'a> {
fn from(module: &'a PyModule) -> WrapPyFunctionArguments<'a> {
WrapPyFunctionArguments::PyModule(module)
}
}
1 change: 1 addition & 0 deletions src/ffi/moduleobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ extern "C" {
pub fn PyModule_New(name: *const c_char) -> *mut PyObject;
#[cfg_attr(PyPy, link_name = "PyPyModule_GetDict")]
pub fn PyModule_GetDict(arg1: *mut PyObject) -> *mut PyObject;
#[cfg_attr(PyPy, link_name = "PyPyModule_GetNameObject")]
pub fn PyModule_GetNameObject(arg1: *mut PyObject) -> *mut PyObject;
#[cfg_attr(PyPy, link_name = "PyPyModule_GetName")]
pub fn PyModule_GetName(arg1: *mut PyObject) -> *const c_char;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
//! #[pymodule]
//! /// A Python module implemented in Rust.
//! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
//! m.add_function(wrap_pyfunction!(sum_as_string))?;
//!
//! Ok(())
//! }
Expand Down
2 changes: 1 addition & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<'p> Python<'p> {
/// let gil = Python::acquire_gil();
/// let py = gil.python();
/// let m = PyModule::new(py, "pcount").unwrap();
/// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap();
/// m.add_function(wrap_pyfunction!(parallel_count)).unwrap();
/// let locals = [("pcount", m)].into_py_dict(py);
/// py.run(r#"
/// s = ["Flow", "my", "tears", "the", "Policeman", "Said"]
Expand Down
41 changes: 40 additions & 1 deletion src/types/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,50 @@ impl PyModule {
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py));
/// ```
pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
///
/// **This function will be deprecated in the next release. Please use the specific
/// [add_function] and [add_module] functions instead.**
pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}

/// Adds a (sub)module to a module.
///
/// Use this together with `#[pymodule]` and [wrap_pymodule!].
///
/// ```rust,ignore
/// m.add_module(wrap_pymodule!(utils));
/// ```
pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
let function = wrapper(self.py());
let name = function
.getattr(self.py(), "__name__")
.expect("A module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}

/// Adds a function to a module, using the functions __name__ as name.
///
/// Use this together with the`#[pyfunction]` and [wrap_pyfunction!].
///
/// ```rust,ignore
/// m.add_function(wrap_pyfunction!(double));
/// ```
///
/// You can also add a function with a custom name using [add](PyModule::add):
///
/// ```rust,ignore
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
/// ```
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
let function = wrapper(self);
let name = function
.getattr(self.py(), "__name__")
.expect("A function or module must have a __name__");
self.add(name.extract(self.py()).unwrap(), function)
}
}
18 changes: 9 additions & 9 deletions tests/test_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn double(x: usize) -> usize {

/// This module is implemented in Rust.
#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

#[pyfn(m, "sum_as_string")]
Expand All @@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {

m.add("foo", "bar").unwrap();

m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(py)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -157,7 +157,7 @@ fn r#move() -> usize {
fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

module.add_wrapped(wrap_pyfunction!(r#move))
module.add_function(wrap_pyfunction!(r#move))
}

#[test]
Expand All @@ -182,7 +182,7 @@ fn custom_named_fn() -> usize {
fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

m.add_wrapped(wrap_pyfunction!(custom_named_fn))?;
m.add_function(wrap_pyfunction!(custom_named_fn))?;
m.dict().set_item("yay", "me")?;
Ok(())
}
Expand Down Expand Up @@ -216,7 +216,7 @@ fn subfunction() -> String {
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::wrap_pyfunction;

module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

Expand All @@ -229,8 +229,8 @@ fn superfunction() -> String {
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
use pyo3::{wrap_pyfunction, wrap_pymodule};

module.add_wrapped(wrap_pyfunction!(superfunction))?;
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_function(wrap_pyfunction!(superfunction))?;
module.add_module(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down Expand Up @@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> {
ext_vararg_fn(py, a, vararg)
}

m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn))
m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn))
.unwrap();
Ok(())
}
Expand Down

0 comments on commit 4e3a35f

Please sign in to comment.