Skip to content

Commit

Permalink
Collecting multiple attribute error (#4243)
Browse files Browse the repository at this point in the history
* collecting multiple errors

* collecting errors from different fileds

* adding changelog

* Adding UI test

* refactoring

* Update pyo3-macros-backend/src/attributes.rs

Co-authored-by: David Hewitt <mail@davidhewitt.dev>

* Update newsfragments/4243.changed.md

Co-authored-by: David Hewitt <mail@davidhewitt.dev>

* Update tests/ui/invalid_pyclass_args.rs

Co-authored-by: David Hewitt <mail@davidhewitt.dev>

* using pural for names

* get rid of internidiate field_options_res

* reset ordering

---------

Co-authored-by: David Hewitt <mail@davidhewitt.dev>
  • Loading branch information
Cheukting and davidhewitt committed Jul 20, 2024
1 parent 5ac5cef commit a84dae0
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 80 deletions.
1 change: 1 addition & 0 deletions newsfragments/4243.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Report multiple errors from `#[pyclass]` and `#[pyo3(..)]` attributes.
38 changes: 33 additions & 5 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,41 @@ pub fn take_attributes(

pub fn take_pyo3_options<T: Parse>(attrs: &mut Vec<syn::Attribute>) -> Result<Vec<T>> {
let mut out = Vec::new();
take_attributes(attrs, |attr| {
if let Some(options) = get_pyo3_options(attr)? {
out.extend(options);
let mut all_errors = ErrorCombiner(None);
take_attributes(attrs, |attr| match get_pyo3_options(attr) {
Ok(result) => {
if let Some(options) = result {
out.extend(options);
Ok(true)
} else {
Ok(false)
}
}
Err(err) => {
all_errors.combine(err);
Ok(true)
} else {
Ok(false)
}
})?;
all_errors.ensure_empty()?;
Ok(out)
}

pub struct ErrorCombiner(pub Option<syn::Error>);

impl ErrorCombiner {
pub fn combine(&mut self, error: syn::Error) {
if let Some(existing) = &mut self.0 {
existing.combine(error);
} else {
self.0 = Some(error);
}
}

pub fn ensure_empty(self) -> Result<()> {
if let Some(error) = self.0 {
Err(error)
} else {
Ok(())
}
}
}
39 changes: 27 additions & 12 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use syn::{parse_quote, parse_quote_spanned, spanned::Spanned, ImplItemFn, Result

use crate::attributes::kw::frozen;
use crate::attributes::{
self, kw, take_pyo3_options, CrateAttribute, ExtendsAttribute, FreelistAttribute,
ModuleAttribute, NameAttribute, NameLitStr, RenameAllAttribute, StrFormatterAttribute,
self, kw, take_pyo3_options, CrateAttribute, ErrorCombiner, ExtendsAttribute,
FreelistAttribute, ModuleAttribute, NameAttribute, NameLitStr, RenameAllAttribute,
StrFormatterAttribute,
};
use crate::konst::{ConstAttributes, ConstSpec};
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
Expand Down Expand Up @@ -252,23 +253,35 @@ pub fn build_py_class(
)
);

let mut all_errors = ErrorCombiner(None);

let mut field_options: Vec<(&syn::Field, FieldPyO3Options)> = match &mut class.fields {
syn::Fields::Named(fields) => fields
.named
.iter_mut()
.map(|field| {
FieldPyO3Options::take_pyo3_options(&mut field.attrs)
.map(move |options| (&*field, options))
})
.collect::<Result<_>>()?,
.filter_map(
|field| match FieldPyO3Options::take_pyo3_options(&mut field.attrs) {
Ok(options) => Some((&*field, options)),
Err(e) => {
all_errors.combine(e);
None
}
},
)
.collect::<Vec<_>>(),
syn::Fields::Unnamed(fields) => fields
.unnamed
.iter_mut()
.map(|field| {
FieldPyO3Options::take_pyo3_options(&mut field.attrs)
.map(move |options| (&*field, options))
})
.collect::<Result<_>>()?,
.filter_map(
|field| match FieldPyO3Options::take_pyo3_options(&mut field.attrs) {
Ok(options) => Some((&*field, options)),
Err(e) => {
all_errors.combine(e);
None
}
},
)
.collect::<Vec<_>>(),
syn::Fields::Unit => {
if let Some(attr) = args.options.set_all {
return Err(syn::Error::new_spanned(attr, UNIT_SET));
Expand All @@ -281,6 +294,8 @@ pub fn build_py_class(
}
};

all_errors.ensure_empty()?;

if let Some(attr) = args.options.get_all {
for (_, FieldPyO3Options { get, .. }) in &mut field_options {
if let Some(old_get) = get.replace(Annotated::Struct(attr)) {
Expand Down
31 changes: 19 additions & 12 deletions tests/ui/invalid_pyclass_args.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::fmt::{Display, Formatter};
use pyo3::prelude::*;
use std::fmt::{Display, Formatter};

#[pyclass(extend=pyo3::types::PyDict)]
struct TypoIntheKey {}
Expand Down Expand Up @@ -74,7 +74,16 @@ impl HashOptAndManualHash {

#[pyclass(ord)]
struct InvalidOrderedStruct {
inner: i32
inner: i32,
}

#[pyclass]
struct MultipleErrors {
#[pyo3(foo)]
#[pyo3(blah)]
x: i32,
#[pyo3(pop)]
y: i32,
}

#[pyclass(str)]
Expand All @@ -88,9 +97,7 @@ impl Display for StrOptAndManualStr {

#[pymethods]
impl StrOptAndManualStr {
fn __str__(
&self,
) -> String {
fn __str__(&self) -> String {
todo!()
}
}
Expand Down Expand Up @@ -123,45 +130,45 @@ pub struct Point2 {
#[derive(PartialEq)]
struct Coord3(u32, u32, u32);

#[pyclass(name = "aaa", str="unsafe: {unsafe_variable}")]
#[pyclass(name = "aaa", str = "unsafe: {unsafe_variable}")]
struct StructRenamingWithStrFormatter {
#[pyo3(name = "unsafe", get, set)]
unsafe_variable: usize,
}

#[pyclass(name = "aaa", str="unsafe: {unsafe_variable}")]
#[pyclass(name = "aaa", str = "unsafe: {unsafe_variable}")]
struct StructRenamingWithStrFormatter2 {
unsafe_variable: usize,
}

#[pyclass(str="unsafe: {unsafe_variable}")]
#[pyclass(str = "unsafe: {unsafe_variable}")]
struct StructRenamingWithStrFormatter3 {
#[pyo3(name = "unsafe", get, set)]
unsafe_variable: usize,
}

#[pyclass(rename_all = "SCREAMING_SNAKE_CASE", str="{a_a}, {b_b}, {c_d_e}")]
#[pyclass(rename_all = "SCREAMING_SNAKE_CASE", str = "{a_a}, {b_b}, {c_d_e}")]
struct RenameAllVariantsStruct {
a_a: u32,
b_b: u32,
c_d_e: String,
}

#[pyclass(str="{:?}")]
#[pyclass(str = "{:?}")]
#[derive(Debug)]
struct StructWithNoMember {
a: String,
b: String,
}

#[pyclass(str="{}")]
#[pyclass(str = "{}")]
#[derive(Debug)]
struct StructWithNoMember2 {
a: String,
b: String,
}

#[pyclass(eq, str="Stuff...")]
#[pyclass(eq, str = "Stuff...")]
#[derive(Debug, PartialEq)]
pub enum MyEnumInvalidStrFmt {
Variant,
Expand Down
Loading

0 comments on commit a84dae0

Please sign in to comment.