Skip to content

Commit

Permalink
Improve diagnostic for invalid function passed to from_py_with
Browse files Browse the repository at this point in the history
  • Loading branch information
mejrs committed Jan 6, 2025
1 parent 93823d2 commit edf46de
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 44 deletions.
91 changes: 54 additions & 37 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
use crate::utils::Ctx;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use quote::{format_ident, quote, quote_spanned};
use syn::{
ext::IdentExt,
parenthesized,
Expand Down Expand Up @@ -264,31 +264,42 @@ impl<'a> Container<'a> {
let struct_name = self.name();
if let Some(ident) = field_ident {
let field_name = ident.to_string();
match from_py_with {
None => quote! {
if let Some(FromPyWithAttribute {
kw,
value: expr_path,
}) = from_py_with
{
let extractor = quote_spanned! { kw.span =>
{ let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
};
quote! {
Ok(#self_ty {
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
})
},
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! {
}
} else {
quote! {
Ok(#self_ty {
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, #field_name)?
#ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
})
},
}
}
} else {
match from_py_with {
None => quote! {
if let Some(FromPyWithAttribute {
kw,
value: expr_path,
}) = from_py_with
{
let extractor = quote_spanned! { kw.span =>
{ let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
};
quote! {
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
}
} else {
quote! {
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
},

Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! {
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, obj, #struct_name, 0).map(#self_ty)
},
}
}
}
}
Expand All @@ -301,16 +312,20 @@ impl<'a> Container<'a> {
.map(|i| format_ident!("arg{}", i))
.collect();
let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
match &field.from_py_with {
None => quote!(
if let Some(FromPyWithAttribute {
kw,
value: expr_path, ..
}) = &field.from_py_with {
let extractor = quote_spanned! { kw.span =>
{ let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
};
quote! {
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
}
} else {
quote!{
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
),
Some(FromPyWithAttribute {
value: expr_path, ..
}) => quote! (
#pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path as fn(_) -> _, &#ident, #struct_name, #index)?
),
}
}}
});

quote!(
Expand Down Expand Up @@ -346,15 +361,17 @@ impl<'a> Container<'a> {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
}
};
let extractor = match &field.from_py_with {
None => {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
}
let extractor = if let Some(FromPyWithAttribute {
kw,
value: expr_path,
}) = &field.from_py_with
{
let extractor = quote_spanned! { kw.span =>
{ let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
};
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
} else {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
};

fields.push(quote!(#ident: #extractor));
Expand Down
10 changes: 7 additions & 3 deletions pyo3-macros-backend/src/params.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::utils::Ctx;
use crate::{
attributes::FromPyWithAttribute,
method::{FnArg, FnSpec, RegularArg},
pyfunction::FunctionSignature,
quotes::some_wrap,
Expand Down Expand Up @@ -248,13 +249,16 @@ pub(crate) fn impl_regular_arg_param(
default = default.map(|tokens| some_wrap(tokens, ctx));
}

if arg.from_py_with.is_some() {
if let Some(FromPyWithAttribute { kw, .. }) = arg.from_py_with {
let extractor = quote_spanned! { kw.span =>
{ let from_py_with: fn(_) -> _ = #from_py_with; from_py_with }
};
if let Some(default) = default {
quote_arg_span! {
#pyo3_path::impl_::extract_argument::from_py_with_with_default(
#arg_value,
#name_str,
#from_py_with as fn(_) -> _,
#extractor,
#[allow(clippy::redundant_closure)]
{
|| #default
Expand All @@ -267,7 +271,7 @@ pub(crate) fn impl_regular_arg_param(
#pyo3_path::impl_::extract_argument::from_py_with(
#unwrap,
#name_str,
#from_py_with as fn(_) -> _,
#extractor,
)?
}
}
Expand Down
14 changes: 10 additions & 4 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::borrow::Cow;
use std::ffi::CString;

use crate::attributes::{NameAttribute, RenamingRule};
use crate::attributes::{FromPyWithAttribute, NameAttribute, RenamingRule};
use crate::method::{CallingConvention, ExtractErrorMode, PyArg};
use crate::params::{impl_regular_arg_param, Holders};
use crate::utils::PythonDoc;
Expand Down Expand Up @@ -1179,14 +1179,20 @@ fn extract_object(
let Ctx { pyo3_path, .. } = ctx;
let name = arg.name().unraw().to_string();

let extract = if let Some(from_py_with) =
arg.from_py_with().map(|from_py_with| &from_py_with.value)
let extract = if let Some(FromPyWithAttribute {
kw,
value: extractor,
}) = arg.from_py_with()
{
let extractor = quote_spanned! { kw.span =>
{ let from_py_with: fn(_) -> _ = #extractor; from_py_with }
};

quote! {
#pyo3_path::impl_::extract_argument::from_py_with(
unsafe { #pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0 },
#name,
#from_py_with as fn(_) -> _,
#extractor,
)
}
} else {
Expand Down
7 changes: 7 additions & 0 deletions tests/ui/invalid_argument_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@ fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] _param: String)
#[pyfunction]
fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] _param: String) {}

fn bytes_from_py(bytes: &Bound<'_, pyo3::types::PyBytes>) -> Vec<u8> {
bytes.as_bytes().to_vec()
}

#[pyfunction]
fn f(#[pyo3(from_py_with = "bytes_from_py")] _bytes: Vec<u8>) {}

fn main() {}
11 changes: 11 additions & 0 deletions tests/ui/invalid_argument_attributes.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,14 @@ error: `from_py_with` may only be specified once per argument
|
16 | fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] _param: String) {}
| ^^^^^^^^^^^^

error[E0308]: mismatched types
--> tests/ui/invalid_argument_attributes.rs:23:13
|
22 | #[pyfunction]
| ------------- here the type of `from_py_with` is inferred to be `fn(&pyo3::Bound<'_, PyBytes>) -> Vec<u8>`
23 | fn f(#[pyo3(from_py_with = "bytes_from_py")] _bytes: Vec<u8>) {}
| ^^^^^^^^^^^^ expected `PyAny`, found `PyBytes`
|
= note: expected fn pointer `fn(&pyo3::Bound<'_, PyAny>) -> Result<_, PyErr>`
found fn pointer `fn(&pyo3::Bound<'_, PyBytes>) -> Vec<u8>`

0 comments on commit edf46de

Please sign in to comment.