Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support default method implementation #2014

Merged
merged 1 commit into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
use crate::deprecations::Deprecations;
use crate::konst::{ConstAttributes, ConstSpec};
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
use crate::utils::{self, unwrap_group, PythonDoc};
use proc_macro2::{Span, TokenStream};
Expand Down Expand Up @@ -425,6 +425,27 @@ fn impl_enum_class(
.impl_all();
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));

let default_repr_impl = {
let variants_repr = variants.iter().map(|variant| {
let variant_name = variant.ident;
// Assuming all variants are unit variants because they are the only type we support.
let repr = format!("{}.{}", cls, variant_name);
quote! { #cls::#variant_name => #repr, }
});
quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
#[pyo3(name = "__repr__")]
fn __pyo3__repr__(&self) -> &'static str {
match self {
#(#variants_repr)*
_ => unreachable!("Unsupported variant type."),
}
}
}
};

let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
Ok(quote! {

#pytypeinfo
Expand All @@ -433,6 +454,8 @@ fn impl_enum_class(

#descriptors

#default_impls

})
}

Expand Down Expand Up @@ -758,6 +781,9 @@ impl<'a> PyClassImplsBuilder<'a> {
// Implementation which uses dtolnay specialization to load all slots.
use ::pyo3::class::impl_::*;
let collector = PyClassImplCollector::<Self>::new();
// This depends on Python implementation detail;
// an old slot entry will be overriden by newer ones.
visitor(collector.py_class_default_slots());
visitor(collector.object_protocol_slots());
visitor(collector.number_protocol_slots());
visitor(collector.iter_protocol_slots());
Expand Down
41 changes: 41 additions & 0 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
}
}

pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
// all succeed.
Comment on lines +142 to +144
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a huge fan of taking Vec<TokenStream> as an argument and all the corresponding unwrap it subsequently leads to.

I think an alternative might be to take Vec<(syn::ItemFn, PyMethod, &SlotDef)> and have the machinery to just produce the slots in here. But TBH that would probably lead to quite a large refactoring so I think it's better here to just focus on getting the end-user output as desired and then we can refactor internals of PyO3 for compile speed to our hearts' content later! 😂

let ty: syn::Type = syn::parse_quote!(#cls);

let mut method_defs: Vec<_> = method_defs
.into_iter()
.map(|token| syn::parse2::<syn::ImplItemMethod>(token).unwrap())
.collect();

let mut proto_impls = Vec::new();

for meth in &mut method_defs {
let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap();
match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() {
GeneratedPyMethod::Proto(token_stream) => {
let attrs = get_cfg_attributes(&meth.attrs);
proto_impls.push(quote!(#(#attrs)* #token_stream))
}
GeneratedPyMethod::SlotTraitImpl(..) => {
panic!("SlotFragment methods cannot have default implementation!")
}
GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => {
panic!("Only protocol methods can have default implementation!")
}
}
}

quote! {
impl #cls {
#(#method_defs)*
}
jovenlin0527 marked this conversation as resolved.
Show resolved Hide resolved
impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls>
for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[#(#proto_impls),*]
}
}
}
}

fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
quote! {
impl ::pyo3::class::impl_::PyMethods<#ty>
Expand Down
3 changes: 3 additions & 0 deletions src/class/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots);
slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots);
slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots);

// slots that PyO3 implements by default, but can be overidden by the users.
slots_trait!(PyClassDefaultSlots, py_class_default_slots);

// Protocol slots from #[pymethods] if not using inventory.
#[cfg(not(feature = "multiple-pymethods"))]
slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots);
Expand Down
41 changes: 41 additions & 0 deletions tests/test_default_impls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use pyo3::prelude::*;

mod common;

// Test default generated __repr__.
#[pyclass]
enum TestDefaultRepr {
Var,
}

#[test]
fn test_default_slot_exists() {
Python::with_gil(|py| {
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
py_assert!(
py,
test_object,
"repr(test_object) == 'TestDefaultRepr.Var'"
);
})
}

#[pyclass]
enum OverrideSlot {
Var,
}

#[pymethods]
impl OverrideSlot {
fn __repr__(&self) -> &str {
"overriden"
}
}

#[test]
fn test_override_slot() {
Python::with_gil(|py| {
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
})
}
10 changes: 10 additions & 0 deletions tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,13 @@ fn test_enum_arg() {

py_run!(py, f mynum, "f(mynum.Variant)")
}

#[test]
fn test_default_repr_correct() {
Python::with_gil(|py| {
let var1 = Py::new(py, MyEnum::Variant).unwrap();
let var2 = Py::new(py, MyEnum::OtherVariant).unwrap();
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
})
}