From 5f82beb7f9774e512b6d575fbefabbb7a432e74a Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 12 Apr 2022 19:37:44 +0100 Subject: [PATCH] WIP: new function signature --- pyo3-macros-backend/src/method.rs | 7 +- pyo3-macros-backend/src/pyfunction.rs | 250 +++++++++++++++------- pytests/src/deprecated_pyfunctions.rs | 66 ++++++ pytests/src/pyfunctions.rs | 3 +- tests/test_compile_error.rs | 1 + tests/ui/invalid_pyfunction_signatures.rs | 11 + 6 files changed, 253 insertions(+), 85 deletions(-) create mode 100644 pytests/src/deprecated_pyfunctions.rs create mode 100644 tests/ui/invalid_pyfunction_signatures.rs diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index ddf2244f199..a85679d22a3 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -3,8 +3,7 @@ use crate::attributes::TextSignatureAttribute; use crate::deprecations::Deprecation; use crate::params::{accept_args_kwargs, impl_arg_params}; -use crate::pyfunction::PyFunctionOptions; -use crate::pyfunction::{PyFunctionArgPyO3Attributes, PyFunctionSignature}; +use crate::pyfunction::{PyFunctionOptions, DeprecatedArgs, PyFunctionArgPyO3Attributes}; use crate::utils::{self, get_pyo3_crate, PythonDoc}; use crate::{deprecations::Deprecations, pyfunction::Argument}; use proc_macro2::{Span, TokenStream}; @@ -716,8 +715,8 @@ fn parse_method_attributes( } }; } else if path.is_ident("args") { - let attrs = PyFunctionSignature::from_meta(&nested)?; - args.extend(attrs.arguments) + let attrs = DeprecatedArgs::from_meta(&nested)?; + args.extend(attrs.0.arguments) } else { new_attrs.push(attr) } diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 60ad2afe9ae..2284ef70799 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -3,20 +3,21 @@ use crate::{ attributes::{ self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute, - FromPyWithAttribute, NameAttribute, TextSignatureAttribute, + FromPyWithAttribute, NameAttribute, TextSignatureAttribute, KeywordAttribute, kw, }, deprecations::Deprecations, method::{self, CallingConvention, FnArg}, pymethod::check_generic, utils::{self, ensure_not_async_fn, get_pyo3_crate}, }; -use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, quote}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; use syn::punctuated::Punctuated; use syn::{ext::IdentExt, spanned::Spanned, NestedMeta, Path, Result}; use syn::{ parse::{Parse, ParseBuffer, ParseStream}, token::Comma, + Token, }; #[derive(Debug, Clone, PartialEq)] @@ -30,16 +31,74 @@ pub enum Argument { Kwarg(syn::Path, Option), } -/// The attributes of the pyfunction macro -#[derive(Default)] -pub struct PyFunctionSignature { - pub arguments: Vec, - has_kw: bool, - has_posonly_args: bool, - has_varargs: bool, - has_kwargs: bool, +pub struct Signature { + paren_token: syn::token::Paren, + items: Punctuated, +} + +impl Parse for Signature { + fn parse(input: ParseStream<'_>) -> syn::Result { + let content; + Ok(Signature { + paren_token: syn::parenthesized!(content in input), + items: content.parse_terminated(SignatureItem::parse)?, + }) + } +} + +impl ToTokens for Signature { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.paren_token.surround(tokens, |tokens| { + self.items.to_tokens(tokens) + }) + } +} + +pub struct SignatureItemArgument { + ident: syn::Ident, + eq_and_default: Option<(Token![=], syn::Expr)>, +} + +pub enum SignatureItem { + Argument(SignatureItemArgument), +} + +impl Parse for SignatureItem { + fn parse(input: ParseStream<'_>) -> syn::Result { + input.parse().map(SignatureItem::Argument) + } +} + +impl ToTokens for SignatureItem { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + SignatureItem::Argument(arg) => { + arg.ident.to_tokens(tokens); + if let Some((eq, default)) = &arg.eq_and_default { + eq.to_tokens(tokens); + default.to_tokens(tokens); + } + } + } + } +} + +impl Parse for SignatureItemArgument { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + ident: input.parse()?, + eq_and_default: if input.peek(Token![=]) { + Some((input.parse()?, input.parse()?)) + } else { + None + }, + }) + } } + +pub type SignatureAttribute = KeywordAttribute; + #[derive(Clone, Debug)] pub struct PyFunctionArgPyO3Attributes { pub from_py_with: Option, @@ -86,16 +145,63 @@ impl PyFunctionArgPyO3Attributes { } } -impl syn::parse::Parse for PyFunctionSignature { +#[derive(Default)] +pub struct FunctionSignature { + pub arguments: Vec, + has_kw: bool, + has_posonly_args: bool, + has_varargs: bool, + has_kwargs: bool, +} + +impl FunctionSignature { + fn from_attribute_and_arguments(attribute: SignatureAttribute, args: &[FnArg<'_>]) -> syn::Result { + let mut reconciled = FunctionSignature { + arguments: Vec::with_capacity(args.len()), + has_kw: false, + has_posonly_args: false, + has_varargs: false, + has_kwargs: false, + }; + + let signature = &attribute.value; + + let mut args_iter = args.iter(); + for item in &signature.items { + match args_iter.next() { + Some(arg) => { + let name = arg.name; + reconciled.arguments.push(Argument::Arg(syn::parse_quote! {#name}, None)); + }, + None => bail_spanned!( + item.span() => "signature entry does not have a corresponding function argument" + ) + } + } + + if let Some(arg) = args_iter.next() { + bail_spanned!( + attribute.span() => format!("missing signature entry for argument `{}`", arg.name) + ); + } + + Ok(Self::default()) + } +} + +pub struct DeprecatedArgs(pub FunctionSignature); + +// Deprecated parsing mode for the signature +impl syn::parse::Parse for DeprecatedArgs { fn parse(input: &ParseBuffer<'_>) -> syn::Result { let attr = Punctuated::::parse_terminated(input)?; Self::from_meta(&attr) } } -impl PyFunctionSignature { +impl DeprecatedArgs { pub fn from_meta<'a>(iter: impl IntoIterator) -> syn::Result { - let mut slf = PyFunctionSignature::default(); + let mut slf = DeprecatedArgs(FunctionSignature::default()); for item in iter { slf.add_item(item)? @@ -124,23 +230,23 @@ impl PyFunctionSignature { syn::Lit::Str(lits) if lits.value() == "*" => { // "*" self.vararg_is_ok(item)?; - self.has_varargs = true; - self.arguments.push(Argument::VarArgsSeparator); + self.0.has_varargs = true; + self.0.arguments.push(Argument::VarArgsSeparator); Ok(()) } syn::Lit::Str(lits) if lits.value() == "/" => { // "/" self.posonly_arg_is_ok(item)?; - self.has_posonly_args = true; + self.0.has_posonly_args = true; // any arguments _before_ this become positional-only - self.arguments.iter_mut().for_each(|a| { + self.0.arguments.iter_mut().for_each(|a| { if let Argument::Arg(path, name) = a { *a = Argument::PosOnlyArg(path.clone(), name.clone()); } else { unreachable!(); } }); - self.arguments.push(Argument::PosOnlyArgsSeparator); + self.0.arguments.push(Argument::PosOnlyArgsSeparator); Ok(()) } _ => bail_spanned!(item.span() => "expected \"/\" or \"*\""), @@ -149,20 +255,20 @@ impl PyFunctionSignature { fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> { ensure_spanned!( - !(self.has_kw || self.has_kwargs), + !(self.0.has_kw || self.0.has_kwargs), item.span() => "positional argument or varargs(*) not allowed after keyword arguments" ); - if self.has_varargs { - self.arguments.push(Argument::Kwarg(path.clone(), None)); + if self.0.has_varargs { + self.0.arguments.push(Argument::Kwarg(path.clone(), None)); } else { - self.arguments.push(Argument::Arg(path.clone(), None)); + self.0.arguments.push(Argument::Arg(path.clone(), None)); } Ok(()) } fn posonly_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { ensure_spanned!( - !(self.has_posonly_args || self.has_kwargs || self.has_varargs), + !(self.0.has_posonly_args || self.0.has_kwargs || self.0.has_varargs), item.span() => "/ is not allowed after /, varargs(*), or kwargs(**)" ); Ok(()) @@ -170,7 +276,7 @@ impl PyFunctionSignature { fn vararg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { ensure_spanned!( - !(self.has_kwargs || self.has_varargs), + !(self.0.has_kwargs || self.0.has_varargs), item.span() => "* is not allowed after varargs(*) or kwargs(**)" ); Ok(()) @@ -178,7 +284,7 @@ impl PyFunctionSignature { fn kw_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { ensure_spanned!( - !self.has_kwargs, + !self.0.has_kwargs, item.span() => "keyword argument or kwargs(**) is not allowed after kwargs(**)" ); Ok(()) @@ -191,13 +297,13 @@ impl PyFunctionSignature { value: String, ) -> syn::Result<()> { self.kw_arg_is_ok(item)?; - if self.has_varargs { + if self.0.has_varargs { // kw only - self.arguments + self.0.arguments .push(Argument::Kwarg(name.clone(), Some(value))); } else { - self.has_kw = true; - self.arguments + self.0.has_kw = true; + self.0.arguments .push(Argument::Arg(name.clone(), Some(value))); } Ok(()) @@ -209,13 +315,13 @@ impl PyFunctionSignature { if litstr.value() == "*" { // args="*" self.vararg_is_ok(item)?; - self.has_varargs = true; - self.arguments.push(Argument::VarArgs(nv.path.clone())); + self.0.has_varargs = true; + self.0.arguments.push(Argument::VarArgs(nv.path.clone())); } else if litstr.value() == "**" { // kwargs="**" self.kw_arg_is_ok(item)?; - self.has_kwargs = true; - self.arguments.push(Argument::KeywordArgs(nv.path.clone())); + self.0.has_kwargs = true; + self.0.arguments.push(Argument::KeywordArgs(nv.path.clone())); } else { self.add_nv_common(item, &nv.path, litstr.value())?; } @@ -236,7 +342,8 @@ impl PyFunctionSignature { pub struct PyFunctionOptions { pub pass_module: Option, pub name: Option, - pub signature: Option, + pub deprecated_args: Option, + pub signature: Option, pub text_signature: Option, pub deprecations: Deprecations, pub krate: Option, @@ -264,7 +371,7 @@ impl Parse for PyFunctionOptions { // If not recognised attribute, this is "legacy" pyfunction syntax #[pyfunction(a, b)] // // TODO deprecate in favour of #[pyfunction(signature = (a, b), name = "foo")] - options.signature = Some(input.parse()?); + options.deprecated_args = Some(input.parse()?); break; } } @@ -276,7 +383,7 @@ impl Parse for PyFunctionOptions { pub enum PyFunctionOption { Name(NameAttribute), PassModule(attributes::kw::pass_module), - Signature(PyFunctionSignature), + Signature(SignatureAttribute), TextSignature(TextSignatureAttribute), Crate(CrateAttribute), } @@ -311,51 +418,28 @@ impl PyFunctionOptions { &mut self, attrs: impl IntoIterator, ) -> Result<()> { - for attr in attrs { - match attr { - PyFunctionOption::Name(name) => self.set_name(name)?, - PyFunctionOption::PassModule(kw) => { - ensure_spanned!( - self.pass_module.is_none(), - kw.span() => "`pass_module` may only be specified once" - ); - self.pass_module = Some(kw); - } - PyFunctionOption::Signature(signature) => { + macro_rules! set_option { + ($key:ident) => { + { ensure_spanned!( - self.signature.is_none(), - // FIXME: improve the span of this error message - Span::call_site() => "`signature` may only be specified once" + self.$key.is_none(), + $key.span() => concat!("`", stringify!($key), "` may only be specified once") ); - self.signature = Some(signature); - } - PyFunctionOption::TextSignature(text_signature) => { - ensure_spanned!( - self.text_signature.is_none(), - text_signature.kw.span() => "`text_signature` may only be specified once" - ); - self.text_signature = Some(text_signature); - } - PyFunctionOption::Crate(path) => { - ensure_spanned!( - self.krate.is_none(), - path.span() => "`crate` may only be specified once" - ); - self.krate = Some(path); + self.$key = Some($key); } + }; + } + for attr in attrs { + match attr { + PyFunctionOption::Name(name) => set_option!(name), + PyFunctionOption::PassModule(pass_module) => set_option!(pass_module), + PyFunctionOption::Signature(signature) => set_option!(signature), + PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature), + PyFunctionOption::Crate(krate) => set_option!(krate), } } Ok(()) } - - pub fn set_name(&mut self, name: NameAttribute) -> Result<()> { - ensure_spanned!( - self.name.is_none(), - name.span() => "`name` may only be specified once" - ); - self.name = Some(name); - Ok(()) - } } pub fn build_py_function( @@ -379,8 +463,6 @@ pub fn impl_wrap_pyfunction( .name .map_or_else(|| func.sig.ident.unraw(), |name| name.value.0); - let signature = options.signature.unwrap_or_default(); - let mut arguments = func .sig .inputs @@ -401,6 +483,14 @@ pub fn impl_wrap_pyfunction( ); } + let signature = if let Some(signature) = options.signature { + FunctionSignature::from_attribute_and_arguments(signature, &arguments)? + } else if let Some(deprecated_args) = options.deprecated_args { + deprecated_args.0 + } else { + FunctionSignature::default() + }; + let ty = method::get_return_info(&func.sig.output); let doc = utils::get_doc( @@ -487,14 +577,14 @@ fn type_is_pymodule(ty: &syn::Type) -> bool { #[cfg(test)] mod tests { - use super::{Argument, PyFunctionSignature}; + use super::{Argument, DeprecatedArgs}; use proc_macro2::TokenStream; use quote::quote; use syn::parse_quote; fn items(input: TokenStream) -> syn::Result> { - let py_fn_attr: PyFunctionSignature = syn::parse2(input)?; - Ok(py_fn_attr.arguments) + let py_fn_attr: DeprecatedArgs = syn::parse2(input)?; + Ok(py_fn_attr.0.arguments) } #[test] diff --git a/pytests/src/deprecated_pyfunctions.rs b/pytests/src/deprecated_pyfunctions.rs new file mode 100644 index 00000000000..a259aea06d6 --- /dev/null +++ b/pytests/src/deprecated_pyfunctions.rs @@ -0,0 +1,66 @@ +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyTuple}; + +#[pyfunction] +fn none() {} + +#[pyfunction(b = "\"bar\"", "*", c = "None")] +fn simple<'a>(a: i32, b: &'a str, c: Option<&'a PyDict>) -> (i32, &'a str, Option<&'a PyDict>) { + (a, b, c) +} + +#[pyfunction(b = "\"bar\"", args = "*", c = "None")] +fn simple_args<'a>( + a: i32, + b: &'a str, + c: Option<&'a PyDict>, + args: &'a PyTuple, +) -> (i32, &'a str, &'a PyTuple, Option<&'a PyDict>) { + (a, b, args, c) +} + +#[pyfunction(b = "\"bar\"", c = "None", kwargs = "**")] +fn simple_kwargs<'a>( + a: i32, + b: &'a str, + c: Option<&'a PyDict>, + kwargs: Option<&'a PyDict>, +) -> (i32, &'a str, Option<&'a PyDict>, Option<&'a PyDict>) { + (a, b, c, kwargs) +} + +#[pyfunction(a, b = "\"bar\"", args = "*", c = "None", kwargs = "**")] +fn simple_args_kwargs<'a>( + a: i32, + b: &'a str, + args: &'a PyTuple, + c: Option<&'a PyDict>, + kwargs: Option<&'a PyDict>, +) -> ( + i32, + &'a str, + &'a PyTuple, + Option<&'a PyDict>, + Option<&'a PyDict>, +) { + (a, b, args, c, kwargs) +} + +#[pyfunction(args = "*", kwargs = "**")] +fn args_kwargs<'a>( + args: &'a PyTuple, + kwargs: Option<&'a PyDict>, +) -> (&'a PyTuple, Option<&'a PyDict>) { + (args, kwargs) +} + +#[pymodule] +pub fn pyfunctions(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(none, m)?)?; + m.add_function(wrap_pyfunction!(simple, m)?)?; + m.add_function(wrap_pyfunction!(simple_args, m)?)?; + m.add_function(wrap_pyfunction!(simple_kwargs, m)?)?; + m.add_function(wrap_pyfunction!(simple_args_kwargs, m)?)?; + m.add_function(wrap_pyfunction!(args_kwargs, m)?)?; + Ok(()) +} diff --git a/pytests/src/pyfunctions.rs b/pytests/src/pyfunctions.rs index a259aea06d6..3a35d2495d7 100644 --- a/pytests/src/pyfunctions.rs +++ b/pytests/src/pyfunctions.rs @@ -29,7 +29,8 @@ fn simple_kwargs<'a>( (a, b, c, kwargs) } -#[pyfunction(a, b = "\"bar\"", args = "*", c = "None", kwargs = "**")] +#[pyfunction] +#[pyo3(signature = (a, b = "\"bar\"", args = "*", c = "None", kwargs = "**"))] fn simple_args_kwargs<'a>( a: i32, b: &'a str, diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 371748db48e..33a49426b35 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -35,6 +35,7 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/invalid_pyclass_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_enum.rs"); t.compile_fail("tests/ui/invalid_pyclass_item.rs"); + t.compile_fail("tests/ui/invalid_pyfunction_signatures.rs"); #[cfg(not(Py_LIMITED_API))] t.compile_fail("tests/ui/invalid_pymethods_buffer.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); diff --git a/tests/ui/invalid_pyfunction_signatures.rs b/tests/ui/invalid_pyfunction_signatures.rs new file mode 100644 index 00000000000..6cafd073757 --- /dev/null +++ b/tests/ui/invalid_pyfunction_signatures.rs @@ -0,0 +1,11 @@ +use pyo3::prelude::*; + +#[pyfunction] +#[pyo3(signature = ())] +fn function_with_one_argument_empty_signature(_x: i32) {} + +#[pyfunction] +#[pyo3(signature = (x))] +fn function_with_one_entry_signature_no_args() {} + +fn main() {}