Skip to content

Commit

Permalink
macros: optimize generated code for #[derive(FromPyObject)]
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Dec 22, 2021
1 parent ff6fb5d commit 492b7e4
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 86 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `PyErr::new_type` now takes an optional docstring and now returns `PyResult<Py<PyType>>` rather than a `ffi::PyTypeObject` pointer.
- The `create_exception!` macro can now take an optional docstring. This docstring, if supplied, is visible to users (with `.__doc__` and `help()`) and
accompanies your error type in your crate's documentation.

- Improve performance and error messages for `#[derive(FromPyObject)]` for enums. [#2068](https://github.com/PyO3/pyo3/pull/2068)

### Removed

- Remove all functionality deprecated in PyO3 0.14. [#2007](https://github.com/PyO3/pyo3/pull/2007)
Expand Down
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ harness = false
[[bench]]
name = "bench_frompyobject"
harness = false
required-features = ["macros"]

[[bench]]
name = "bench_gil"
Expand All @@ -106,6 +107,7 @@ harness = false
[[bench]]
name = "bench_pyclass"
harness = false
required-features = ["macros"]

[[bench]]
name = "bench_pyobject"
Expand Down
2 changes: 1 addition & 1 deletion benches/bench_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::{prelude::*, types::PyString};
enum ManyTypes {
Int(i32),
Bytes(Vec<u8>),
String(String)
String(String),
}

fn enum_from_pyobject(b: &mut Bencher) {
Expand Down
79 changes: 32 additions & 47 deletions benches/bench_pyclass.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,49 @@
#[cfg(feature = "macros")]
use criterion::{criterion_group, criterion_main, Criterion};
use pyo3::{class::PyObjectProtocol, prelude::*, type_object::LazyStaticType};

#[cfg(feature = "macros")]
mod m {
use pyo3::{class::PyObjectProtocol, prelude::*, type_object::LazyStaticType};
/// This is a feature-rich class instance used to benchmark various parts of the pyclass lifecycle.
#[pyclass]
struct MyClass {
#[pyo3(get, set)]
elements: Vec<i32>,
}

/// This is a feature-rich class instance used to benchmark various parts of the pyclass lifecycle.
#[pyclass]
struct MyClass {
#[pyo3(get, set)]
elements: Vec<i32>,
#[pymethods]
impl MyClass {
#[new]
fn new(elements: Vec<i32>) -> Self {
Self { elements }
}

#[pymethods]
impl MyClass {
#[new]
fn new(elements: Vec<i32>) -> Self {
Self { elements }
}

fn __call__(&mut self, new_element: i32) -> usize {
self.elements.push(new_element);
self.elements.len()
}
fn __call__(&mut self, new_element: i32) -> usize {
self.elements.push(new_element);
self.elements.len()
}
}

#[pyproto]
impl PyObjectProtocol for MyClass {
/// A basic __str__ implementation.
fn __str__(&self) -> &'static str {
"MyClass"
}
#[pyproto]
impl PyObjectProtocol for MyClass {
/// A basic __str__ implementation.
fn __str__(&self) -> &'static str {
"MyClass"
}
}

pub fn first_time_init(b: &mut criterion::Bencher) {
let gil = Python::acquire_gil();
let py = gil.python();
b.iter(|| {
// This is using an undocumented internal PyO3 API to measure pyclass performance; please
// don't use this in your own code!
let ty = LazyStaticType::new();
ty.get_or_init::<MyClass>(py);
});
}
pub fn first_time_init(b: &mut criterion::Bencher) {
let gil = Python::acquire_gil();
let py = gil.python();
b.iter(|| {
// This is using an undocumented internal PyO3 API to measure pyclass performance; please
// don't use this in your own code!
let ty = LazyStaticType::new();
ty.get_or_init::<MyClass>(py);
});
}

#[cfg(feature = "macros")]
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("first_time_init", m::first_time_init);
c.bench_function("first_time_init", first_time_init);
}

#[cfg(feature = "macros")]
criterion_group!(benches, criterion_benchmark);

#[cfg(feature = "macros")]
criterion_main!(benches);

#[cfg(not(feature = "macros"))]
fn main() {
unimplemented!(
"benchmarking `bench_pyclass` is only available with the `macros` feature enabled"
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,38 +54,39 @@ impl<'a> Enum<'a> {
/// Build derivation body for enums.
fn build(&self) -> TokenStream {
let mut var_extracts = Vec::new();
let mut error_names = String::new();
for (i, var) in self.variants.iter().enumerate() {
let mut variant_names = Vec::new();
let mut error_names = Vec::new();
for var in &self.variants {
let struct_derive = var.build();
let ext = quote!(
let ext = quote!({
let maybe_ret = || -> _pyo3::PyResult<Self> {
#struct_derive
}();

match maybe_ret {
ok @ ::std::result::Result::Ok(_) => return ok,
::std::result::Result::Err(err) => {
let py = _pyo3::PyNativeType::py(obj);
err_reasons.push_str(&::std::format!("{}\n", err.value(py).str()?));
}
::std::result::Result::Err(err) => err
}
);
});

var_extracts.push(ext);
if i > 0 {
error_names.push_str(" | ");
}
error_names.push_str(&var.err_name);
variant_names.push(var.path.segments.last().unwrap().ident.to_string());
error_names.push(&var.err_name);
}
let ty_name = self.enum_ident.to_string();
quote!(
let mut err_reasons = ::std::string::String::new();
#(#var_extracts)*
let err_msg = ::std::format!("failed to extract enum {} ('{}')\n{}",
#ty_name,
#error_names,
&err_reasons);
::std::result::Result::Err(_pyo3::exceptions::PyTypeError::new_err(err_msg))
let errors = [
#(#var_extracts),*
];
::std::result::Result::Err(
_pyo3::impl_::frompyobject::failed_to_extract_enum(
obj.py(),
#ty_name,
&[#(#variant_names),*],
&[#(#error_names),*],
&errors
)
)
)
}
}
Expand Down Expand Up @@ -216,13 +217,8 @@ impl<'a> Container<'a> {
new_err
})?})
)
} else {
let error_msg = if self.is_enum_variant {
let variant_name = &self.path.segments.last().unwrap();
format!("- variant {} ({})", quote!(#variant_name), &self.err_name)
} else {
format!("failed to extract inner field of {}", quote!(#self_ty))
};
} else if !self.is_enum_variant {
let error_msg = format!("failed to extract inner field of {}", quote!(#self_ty));
quote!(
::std::result::Result::Ok(#self_ty(obj.extract().map_err(|err| {
let py = _pyo3::PyNativeType::py(obj);
Expand All @@ -232,6 +228,8 @@ impl<'a> Container<'a> {
_pyo3::exceptions::PyTypeError::new_err(err_msg)
})?))
)
} else {
quote!(obj.extract().map(#self_ty))
}
}

Expand Down
4 changes: 2 additions & 2 deletions pyo3-macros-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod utils;
mod attributes;
mod defs;
mod deprecations;
mod from_pyobject;
mod frompyobject;
mod konst;
mod method;
mod module;
Expand All @@ -23,7 +23,7 @@ mod pyimpl;
mod pymethod;
mod pyproto;

pub use from_pyobject::build_derive_from_pyobject;
pub use frompyobject::build_derive_from_pyobject;
pub use module::{process_functions_in_module, py_init, PyModuleOptions};
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionOptions};
Expand Down
2 changes: 2 additions & 0 deletions src/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
pub mod deprecations;
pub mod freelist;
#[doc(hidden)]
pub mod frompyobject;
26 changes: 26 additions & 0 deletions src/impl_/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::{exceptions::PyTypeError, PyErr, Python};

#[cold]
pub fn failed_to_extract_enum(
py: Python,
type_name: &str,
variant_names: &[&str],
error_names: &[&str],
errors: &[PyErr],
) -> PyErr {
let mut err_msg = format!(
"failed to extract enum {} ('{}')",
type_name,
error_names.join(" | ")
);
for ((variant_name, error_name), error) in variant_names.iter().zip(error_names).zip(errors) {
err_msg.push('\n');
err_msg.push_str(&format!(
"- variant {variant_name} ({error_name}): {error_msg}",
variant_name = variant_name,
error_name = error_name,
error_msg = error.value(py).str().unwrap().to_str().unwrap(),
));
}
PyTypeError::new_err(err_msg)
}
54 changes: 45 additions & 9 deletions tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,6 @@ pub enum Foo<'a> {
#[pyo3(item("foo"))]
a: String,
},
#[pyo3(transparent)]
CatchAll(&'a PyAny),
}

#[pyclass]
Expand Down Expand Up @@ -381,15 +379,52 @@ fn test_enum() {
Foo::StructWithGetItemArg { a } => assert_eq!(a, "test"),
_ => panic!("Expected extracting Foo::StructWithGetItemArg, got {:?}", f),
}
});
}

#[test]
fn test_enum_error() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict");
let err = Foo::extract(dict.as_ref()).unwrap_err();
assert_eq!(
err.to_string(),
"\
TypeError: failed to extract enum Foo ('TupleVar | StructVar | TransparentTuple | TransparentStructVar | StructVarGetAttrArg | StructWithGetItem | StructWithGetItemArg')
- variant TupleVar (TupleVar): 'dict' object cannot be converted to 'PyTuple'
- variant StructVar (StructVar): 'dict' object has no attribute 'test'
- variant TransparentTuple (TransparentTuple): 'dict' object cannot be interpreted as an integer
- variant TransparentStructVar (TransparentStructVar): failed to extract field Foo :: TransparentStructVar.a
- variant StructVarGetAttrArg (StructVarGetAttrArg): 'dict' object has no attribute 'bla'
- variant StructWithGetItem (StructWithGetItem): 'a'
- variant StructWithGetItemArg (StructWithGetItemArg): 'foo'"
);
});
}

#[derive(Debug, FromPyObject)]
enum EnumWithCatchAll<'a> {
#[pyo3(transparent)]
Foo(Foo<'a>),
#[pyo3(transparent)]
CatchAll(&'a PyAny),
}

#[test]
fn test_enum_catch_all() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
let f = EnumWithCatchAll::extract(dict.as_ref())
.expect("Failed to extract EnumWithCatchAll from dict");
match f {
Foo::CatchAll(any) => {
EnumWithCatchAll::CatchAll(any) => {
let d = <&PyDict>::extract(any).expect("Expected pydict");
assert!(d.is_empty());
}
_ => panic!("Expected extracting Foo::CatchAll, got {:?}", f),
_ => panic!(
"Expected extracting EnumWithCatchAll::CatchAll, got {:?}",
f
),
}
});
}
Expand All @@ -412,10 +447,11 @@ fn test_err_rename() {
assert!(f.is_err());
assert_eq!(
f.unwrap_err().to_string(),
"TypeError: failed to extract enum Bar (\'str | uint | int\')\n- variant A (str): \
\'dict\' object cannot be converted to \'PyString\'\n- variant B (uint): \'dict\' object \
cannot be interpreted as an integer\n- variant C (int): \'dict\' object cannot be \
interpreted as an integer\n"
"\
TypeError: failed to extract enum Bar (\'str | uint | int\')
- variant A (str): \'dict\' object cannot be converted to \'PyString\'
- variant B (uint): \'dict\' object cannot be interpreted as an integer
- variant C (int): \'dict\' object cannot be interpreted as an integer"
);
});
}
Expand Down

0 comments on commit 492b7e4

Please sign in to comment.