Skip to content

Commit

Permalink
Enable &Self in #[pymethods] again
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jul 28, 2020
1 parent b05eb48 commit f5f2e84
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Fix segfault with #[pyclass(dict, unsendable)] [#1058](https://github.com/PyO3/pyo3/pull/1058)
- Don't rely on the order of structmembers to compute offsets in PyCell. Related to
[#1058](https://github.com/PyO3/pyo3/pull/1058). [#1059](https://github.com/PyO3/pyo3/pull/1059)
- Allows `&Self` as a `#[pymethods]` argument again. [#1071](https://github.com/PyO3/pyo3/pull/1071)

## [0.11.1] - 2020-06-30
### Added
Expand Down
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ mod utils;
pub use module::{add_fn_to_module, process_functions_in_module, py_init};
pub use pyclass::{build_py_class, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionAttr};
pub use pyimpl::{build_py_methods, impl_methods};
pub use pyimpl::build_py_methods;
pub use pyproto::build_py_proto;
pub use utils::get_doc;
2 changes: 1 addition & 1 deletion pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
#name(#(#names),*)
};

let body = pymethod::impl_arg_params(spec, cb);
let body = pymethod::impl_arg_params(spec, None, cb);

quote! {
unsafe extern "C" fn __wrap(
Expand Down
47 changes: 37 additions & 10 deletions pyo3-derive-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn impl_wrap_common(
}
}
} else {
let body = impl_arg_params(&spec, body);
let body = impl_arg_params(&spec, Some(cls), body);

quote! {
unsafe extern "C" fn __wrap(
Expand All @@ -138,7 +138,7 @@ fn impl_wrap_common(
pub fn impl_proto_wrap(cls: &syn::Type, spec: &FnSpec<'_>, self_ty: &SelfType) -> TokenStream {
let python_name = &spec.python_name;
let cb = impl_call(cls, &spec);
let body = impl_arg_params(&spec, cb);
let body = impl_arg_params(&spec, Some(cls), cb);
let slf = self_ty.receiver(cls);

quote! {
Expand Down Expand Up @@ -166,7 +166,7 @@ pub fn impl_wrap_new(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
let python_name = &spec.python_name;
let names: Vec<syn::Ident> = get_arg_names(&spec);
let cb = quote! { #cls::#name(#(#names),*) };
let body = impl_arg_params(spec, cb);
let body = impl_arg_params(spec, Some(cls), cb);

quote! {
#[allow(unused_mut)]
Expand Down Expand Up @@ -198,7 +198,7 @@ pub fn impl_wrap_class(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
let names: Vec<syn::Ident> = get_arg_names(&spec);
let cb = quote! { #cls::#name(&_cls, #(#names),*) };

let body = impl_arg_params(spec, cb);
let body = impl_arg_params(spec, Some(cls), cb);

quote! {
#[allow(unused_mut)]
Expand Down Expand Up @@ -226,7 +226,7 @@ pub fn impl_wrap_static(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
let names: Vec<syn::Ident> = get_arg_names(&spec);
let cb = quote! { #cls::#name(#(#names),*) };

let body = impl_arg_params(spec, cb);
let body = impl_arg_params(spec, Some(cls), cb);

quote! {
#[allow(unused_mut)]
Expand Down Expand Up @@ -383,7 +383,11 @@ fn impl_call(cls: &syn::Type, spec: &FnSpec<'_>) -> TokenStream {
quote! { #cls::#fname(_slf, #(#names),*) }
}

pub fn impl_arg_params(spec: &FnSpec<'_>, body: TokenStream) -> TokenStream {
pub fn impl_arg_params(
spec: &FnSpec<'_>,
self_: Option<&syn::Type>,
body: TokenStream,
) -> TokenStream {
if spec.args.is_empty() {
return quote! {
#body
Expand Down Expand Up @@ -412,7 +416,7 @@ pub fn impl_arg_params(spec: &FnSpec<'_>, body: TokenStream) -> TokenStream {
let mut param_conversion = Vec::new();
let mut option_pos = 0;
for (idx, arg) in spec.args.iter().enumerate() {
param_conversion.push(impl_arg_param(&arg, &spec, idx, &mut option_pos));
param_conversion.push(impl_arg_param(&arg, &spec, idx, self_, &mut option_pos));
}

let (mut accept_args, mut accept_kwargs) = (false, false);
Expand Down Expand Up @@ -458,6 +462,7 @@ fn impl_arg_param(
arg: &FnArg<'_>,
spec: &FnSpec<'_>,
idx: usize,
self_: Option<&syn::Type>,
option_pos: &mut usize,
) -> TokenStream {
let arg_name = syn::Ident::new(&format!("arg{}", idx), Span::call_site());
Expand Down Expand Up @@ -491,7 +496,7 @@ fn impl_arg_param(
quote! { None }
};
if let syn::Type::Reference(tref) = ty {
let (tref, mut_) = tref_preprocess(tref);
let (tref, mut_) = preprocess_tref(tref, self_);
// To support Rustc 1.39.0, we don't use as_deref here...
let tmp_as_deref = if mut_.is_some() {
quote! { _tmp.as_mut().map(std::ops::DerefMut::deref_mut) }
Expand Down Expand Up @@ -524,7 +529,7 @@ fn impl_arg_param(
};
}
} else if let syn::Type::Reference(tref) = arg.ty {
let (tref, mut_) = tref_preprocess(tref);
let (tref, mut_) = preprocess_tref(tref, self_);
// Get &T from PyRef<T>
quote! {
let #mut_ _tmp: <#tref as pyo3::derive_utils::ExtractExt>::Target
Expand All @@ -537,12 +542,34 @@ fn impl_arg_param(
}
};

fn tref_preprocess(tref: &syn::TypeReference) -> (syn::TypeReference, Option<syn::token::Mut>) {
/// Replace `Self`, remove lifetime and get mutability from the type
fn preprocess_tref(
tref: &syn::TypeReference,
self_: Option<&syn::Type>,
) -> (syn::TypeReference, Option<syn::token::Mut>) {
let mut tref = tref.to_owned();
if let Some(syn::Type::Path(tpath)) = self_ {
replace_self(&mut tref, &tpath.path);
}
tref.lifetime = None;
let mut_ = tref.mutability;
(tref, mut_)
}

/// Replace `Self` with the exact type name since it is used out of the impl block
fn replace_self(tref: &mut syn::TypeReference, self_path: &syn::Path) {
match &mut *tref.elem {
syn::Type::Reference(tref_inner) => replace_self(tref_inner, self_path),
syn::Type::Path(ref mut tpath) => {
if let Some(ident) = tpath.path.get_ident() {
if ident == "Self" {
tpath.path = self_path.to_owned();
}
}
}
_ => {}
}
}
}

pub fn impl_py_method_def(spec: &FnSpec, wrapper: &TokenStream) -> TokenStream {
Expand Down
12 changes: 8 additions & 4 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ impl InstanceMethod {
fn method(&self) -> PyResult<i32> {
Ok(self.member)
}

// Checks that &Self works
fn add_other(&self, other: &Self) -> i32 {
self.member + other.member
}
}

#[test]
Expand All @@ -26,10 +31,9 @@ fn instance_method() {
let obj = PyCell::new(py, InstanceMethod { member: 42 }).unwrap();
let obj_ref = obj.borrow();
assert_eq!(obj_ref.method().unwrap(), 42);
let d = [("obj", obj)].into_py_dict(py);
py.run("assert obj.method() == 42", None, Some(d)).unwrap();
py.run("assert obj.method.__doc__ == 'Test method'", None, Some(d))
.unwrap();
py_assert!(py, obj, "obj.method() == 42");
py_assert!(py, obj, "obj.add_other(obj) == 84");
py_assert!(py, obj, "obj.method.__doc__ == 'Test method'");
}

#[pyclass]
Expand Down

0 comments on commit f5f2e84

Please sign in to comment.