Skip to content

Commit

Permalink
allow **kwargs to take arguments which conflict with positional-only …
Browse files Browse the repository at this point in the history
…parameters
  • Loading branch information
davidhewitt committed Dec 9, 2022
1 parent 55592af commit e3bfc45
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 89 deletions.
1 change: 1 addition & 0 deletions newsfragments/2800.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow functions taking `**kwargs` to accept keyword arguments which share a name with a positional-only argument (as permitted by PEP 570).
182 changes: 93 additions & 89 deletions src/impl_/extract_argument.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,75 +237,49 @@ impl FunctionDescription {
V: VarargsHandler<'py>,
K: VarkeywordsHandler<'py>,
{
// Safety: Option<&PyAny> has the same memory layout as `*mut ffi::PyObject`
let args = args as *const Option<&PyAny>;
let positional_args_provided = nargs as usize;
let args_slice = if args.is_null() {
&[]
} else {
std::slice::from_raw_parts(args, positional_args_provided)
};

let num_positional_parameters = self.positional_parameter_names.len();

debug_assert!(nargs >= 0);
debug_assert!(self.positional_only_parameters <= num_positional_parameters);
debug_assert!(self.required_positional_parameters <= num_positional_parameters);
debug_assert_eq!(
output.len(),
num_positional_parameters + self.keyword_only_parameters.len()
);

let varargs = if positional_args_provided > num_positional_parameters {
let (positional_parameters, varargs) = args_slice.split_at(num_positional_parameters);
output[..num_positional_parameters].copy_from_slice(positional_parameters);
V::handle_varargs_fastcall(py, varargs, self)?
// Handle positional arguments
// Safety: Option<&PyAny> has the same memory layout as `*mut ffi::PyObject`
let args: *const Option<&PyAny> = args.cast();
let positional_args_provided = nargs as usize;
let remaining_positional_args = if args.is_null() {
debug_assert_eq!(positional_args_provided, 0);
&[]
} else {
output[..positional_args_provided].copy_from_slice(args_slice);
V::handle_varargs_fastcall(py, &[], self)?
// Can consume at most the number of positional parameters in the function definition,
// the rest are varargs.
let positional_args_to_consume =
num_positional_parameters.min(positional_args_provided);
let (positional_parameters, remaining) =
std::slice::from_raw_parts(args, positional_args_provided)
.split_at(positional_args_to_consume);
output[..positional_args_to_consume].copy_from_slice(positional_parameters);
remaining
};
let varargs = V::handle_varargs_fastcall(py, remaining_positional_args, self)?;

// Handle keyword arguments
let mut varkeywords = Default::default();
let mut varkeywords = K::Varkeywords::default();
if let Some(kwnames) = py.from_borrowed_ptr_or_opt::<PyTuple>(kwnames) {
let mut positional_only_keyword_arguments = Vec::new();

// Safety: &PyAny has the same memory layout as `*mut ffi::PyObject`
let kwargs =
::std::slice::from_raw_parts((args as *const &PyAny).offset(nargs), kwnames.len());

for (kwarg_name_py, &value) in kwnames.iter().zip(kwargs) {
// All keyword arguments should be UTF8 strings, but we'll check, just in case.
if let Ok(kwarg_name) = kwarg_name_py.downcast::<PyString>()?.to_str() {
// Try to place parameter in keyword only parameters
if let Some(i) = self.find_keyword_parameter_in_keyword_only(kwarg_name) {
if output[i + num_positional_parameters]
.replace(value)
.is_some()
{
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
}

// Repeat for positional parameters
if let Some(i) = self.find_keyword_parameter_in_positional(kwarg_name) {
if i < self.positional_only_parameters {
positional_only_keyword_arguments.push(kwarg_name);
} else if output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
}
};

K::handle_unexpected_keyword(&mut varkeywords, kwarg_name_py, value, self)?
}

if !positional_only_keyword_arguments.is_empty() {
return Err(
self.positional_only_keyword_arguments(&positional_only_keyword_arguments)
);
}
self.handle_kwargs::<K>(
kwnames.iter().zip(kwargs.iter().copied()),
&mut varkeywords,
num_positional_parameters,
output,
)?
}

// Once all inputs have been processed, check that all required arguments have been provided.
Expand Down Expand Up @@ -360,50 +334,80 @@ impl FunctionDescription {
let varargs = V::handle_varargs_tuple(args, self)?;

// Handle keyword arguments
let mut varkeywords = Default::default();
let mut varkeywords = K::Varkeywords::default();
if let Some(kwargs) = kwargs {
let mut positional_only_keyword_arguments = Vec::new();
for (kwarg_name_py, value) in kwargs {
// All keyword arguments should be UTF8 strings, but we'll check, just in case.
if let Ok(kwarg_name) = kwarg_name_py.downcast::<PyString>()?.to_str() {
// Try to place parameter in keyword only parameters
if let Some(i) = self.find_keyword_parameter_in_keyword_only(kwarg_name) {
if output[i + num_positional_parameters]
.replace(value)
.is_some()
{
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
self.handle_kwargs::<K>(kwargs, &mut varkeywords, num_positional_parameters, output)?
}

// Once all inputs have been processed, check that all required arguments have been provided.

self.ensure_no_missing_required_positional_arguments(output, args.len())?;
self.ensure_no_missing_required_keyword_arguments(output)?;

Ok((varargs, varkeywords))
}

#[inline]
fn handle_kwargs<'py, K>(
&self,
kwargs: impl IntoIterator<Item = (&'py PyAny, &'py PyAny)>,
varkeywords: &mut K::Varkeywords,
num_positional_parameters: usize,
output: &mut [Option<&'py PyAny>],
) -> PyResult<()>
where
K: VarkeywordsHandler<'py>,
{
debug_assert_eq!(
num_positional_parameters,
self.positional_parameter_names.len()
);
debug_assert_eq!(
output.len(),
num_positional_parameters + self.keyword_only_parameters.len()
);
let mut positional_only_keyword_arguments = Vec::new();
for (kwarg_name_py, value) in kwargs {
// All keyword arguments should be UTF-8 strings, but we'll check, just in case.
// If it isn't, then it will be handled below as a varkeyword (which may raise an
// error if this function doesn't accept **kwargs). Rust source is always UTF-8
// and so all argument names in `#[pyfunction]` signature must be UTF-8.
if let Ok(kwarg_name) = kwarg_name_py.downcast::<PyString>()?.to_str() {
// Try to place parameter in keyword only parameters
if let Some(i) = self.find_keyword_parameter_in_keyword_only(kwarg_name) {
if output[i + num_positional_parameters]
.replace(value)
.is_some()
{
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
}

// Repeat for positional parameters
if let Some(i) = self.find_keyword_parameter_in_positional(kwarg_name) {
if i < self.positional_only_parameters {
// Repeat for positional parameters
if let Some(i) = self.find_keyword_parameter_in_positional(kwarg_name) {
if i < self.positional_only_parameters {
// If accepting **kwargs, then it's allowed for the name of the
// kwarg to conflict with a postional-only argument - the value
// will go into **kwargs anyway.
if K::handle_varkeyword(varkeywords, kwarg_name_py, value, self).is_err() {
positional_only_keyword_arguments.push(kwarg_name);
} else if output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(kwarg_name));
}
continue;
} else if output[i].replace(value).is_some() {
return Err(self.multiple_values_for_argument(kwarg_name));
}
};

K::handle_unexpected_keyword(&mut varkeywords, kwarg_name_py, value, self)?
}
continue;
}
};

if !positional_only_keyword_arguments.is_empty() {
return Err(
self.positional_only_keyword_arguments(&positional_only_keyword_arguments)
);
}
K::handle_varkeyword(varkeywords, kwarg_name_py, value, self)?
}

// Once all inputs have been processed, check that all required arguments have been provided.

self.ensure_no_missing_required_positional_arguments(output, args.len())?;
self.ensure_no_missing_required_keyword_arguments(output)?;
if !positional_only_keyword_arguments.is_empty() {
return Err(self.positional_only_keyword_arguments(&positional_only_keyword_arguments));
}

Ok((varargs, varkeywords))
Ok(())
}

#[inline]
Expand Down Expand Up @@ -637,10 +641,10 @@ impl<'py> VarargsHandler<'py> for TupleVarargs {
}
}

/// A trait used to control whether to accept unrecognised keywords in FunctionDescription::extract_argument_(method) functions.
/// A trait used to control whether to accept varkeywords in FunctionDescription::extract_argument_(method) functions.
pub trait VarkeywordsHandler<'py> {
type Varkeywords: Default;
fn handle_unexpected_keyword(
fn handle_varkeyword(
varkeywords: &mut Self::Varkeywords,
name: &'py PyAny,
value: &'py PyAny,
Expand All @@ -654,7 +658,7 @@ pub struct NoVarkeywords;
impl<'py> VarkeywordsHandler<'py> for NoVarkeywords {
type Varkeywords = ();
#[inline]
fn handle_unexpected_keyword(
fn handle_varkeyword(
_varkeywords: &mut Self::Varkeywords,
name: &'py PyAny,
_value: &'py PyAny,
Expand All @@ -670,7 +674,7 @@ pub struct DictVarkeywords;
impl<'py> VarkeywordsHandler<'py> for DictVarkeywords {
type Varkeywords = Option<&'py PyDict>;
#[inline]
fn handle_unexpected_keyword(
fn handle_varkeyword(
varkeywords: &mut Self::Varkeywords,
name: &'py PyAny,
value: &'py PyAny,
Expand Down
26 changes: 26 additions & 0 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,16 @@ impl MethSignature {
[a.to_object(py), kwargs.to_object(py)].to_object(py)
}

#[pyo3(signature = (a=0, /, **kwargs))]
fn get_optional_pos_only_with_kwargs(
&self,
py: Python<'_>,
a: i32,
kwargs: Option<&PyDict>,
) -> PyObject {
[a.to_object(py), kwargs.to_object(py)].to_object(py)
}

#[pyo3(signature = (*, a = 2, b = 3))]
fn get_kwargs_only_with_defaults(&self, a: i32, b: i32) -> i32 {
a + b
Expand Down Expand Up @@ -961,6 +971,22 @@ fn meth_signature() {
PyTypeError
);

py_run!(
py,
inst,
"assert inst.get_optional_pos_only_with_kwargs() == [0, None]"
);
py_run!(
py,
inst,
"assert inst.get_optional_pos_only_with_kwargs(10) == [10, None]"
);
py_run!(
py,
inst,
"assert inst.get_optional_pos_only_with_kwargs(a=10) == [0, {'a': 10}]"
);

py_run!(py, inst, "assert inst.get_kwargs_only_with_defaults() == 5");
py_run!(
py,
Expand Down

0 comments on commit e3bfc45

Please sign in to comment.