diff --git a/newsfragments/3350.added.md b/newsfragments/3350.added.md new file mode 100644 index 00000000000..7e3a566fe06 --- /dev/null +++ b/newsfragments/3350.added.md @@ -0,0 +1 @@ +Added derive macro for ```IntoPyDict``` \ No newline at end of file diff --git a/pyo3-macros-backend/src/intopydict.rs b/pyo3-macros-backend/src/intopydict.rs new file mode 100644 index 00000000000..5ab2a410384 --- /dev/null +++ b/pyo3-macros-backend/src/intopydict.rs @@ -0,0 +1,366 @@ +use quote::{quote, TokenStreamExt}; +use std::{collections::HashMap, ops::AddAssign}; + +use proc_macro2::{Span, TokenStream}; +use syn::{ + parse::{Parse, ParseStream}, + DeriveInput, Error, Generics, +}; + +const COL_NAMES: [&str; 8] = [ + "BTreeSet", + "BinaryHeap", + "Vec", + "HashSet", + "LinkedList", + "VecDeque", + "BTreeMap", + "HashMap", +]; + +#[derive(Debug, Clone)] +enum Pyo3Type { + Primitive, + NonPrimitive, + CollectionSing(Box), + // Map( + // Box, + // Box, + // ), +} + +#[derive(Debug, Clone)] +pub struct Pyo3DictField { + name: String, + attr_type: Pyo3Type, + attr_name: Option, +} + +impl Pyo3DictField { + pub fn new(name: String, type_: &str, span: Span, attr_name: Option) -> Self { + Self { + name, + attr_type: Self::check_primitive(type_, span), + attr_name, + } + } + + fn check_primitive(attr_type: &str, span: Span) -> Pyo3Type { + for collection in COL_NAMES { + if attr_type.starts_with(collection) { + let attr_type = attr_type.replace('>', ""); + let attr_list: Vec<&str> = attr_type.split('<').collect(); + let out = Self::handle_collection(&attr_list, span); + + return out.unwrap(); + } + } + + match attr_type { + "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" + | "u128" | "usize" | "f32" | "f64" | "char" | "bool" | "&str" | "String" => { + Pyo3Type::Primitive + } + _ => Pyo3Type::NonPrimitive, + } + } + + fn handle_collection(attr_type: &[&str], span: Span) -> syn::Result { + match attr_type[0] { + "BTreeSet" | "BinaryHeap" | "Vec" | "HashSet" | "LinkedList" | "VecDeque" => { + Ok(Pyo3Type::CollectionSing(Box::new( + Self::handle_collection(&attr_type[1..], span).unwrap(), + ))) + } + "BTreeMap" | "HashMap" => { + Err(Error::new(span, "Derive currently doesn't support map types. Please use a custom implementation for structs using a map type like HashMap or BTreeMap")) + } + "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" + | "u128" | "usize" | "f32" | "f64" | "char" | "bool" | "&str" | "String" => { + Ok(Pyo3Type::Primitive) + } + _ => Ok(Pyo3Type::NonPrimitive), + } + } +} + +impl Parse for Pyo3Collection { + fn parse(input: ParseStream<'_>) -> syn::Result { + let tok_stream: TokenStream = input.parse()?; + let mut binding = tok_stream + .to_string() + .as_str() + .replace(|c| c == ' ' || c == '{' || c == '}' || c == '\n', ""); + + if !binding.contains(':') { + return Ok(Pyo3Collection(Vec::new())); + } + + if binding.as_bytes()[binding.len() - 1] as char != ',' { + binding.push(','); + } + + let name_map = split_struct(binding); + + let mut field_collection: Vec = Vec::new(); + + for (field_name, (field_val, dict_key)) in &name_map { + field_collection.push(Pyo3DictField::new( + field_name.to_string(), + field_val, + input.span(), + dict_key.clone(), + )) + } + + Ok(Pyo3Collection(field_collection)) + } +} + +fn split_struct(binding: String) -> HashMap)> { + let mut stack: Vec = Vec::new(); + let mut start = 0; + let binding = binding.replace('\n', ""); + let mut name_map: HashMap)> = HashMap::new(); + + for (i, char_val) in binding.chars().enumerate() { + if char_val == ',' && stack.is_empty() { + if binding[start..i].starts_with('#') { + let new_name = get_new_name(binding.clone(), start, i); + let var_string = &binding[start..i].split(']').collect::>()[1]; + let info_parsed = var_string.split(':').collect::>(); + name_map.insert( + info_parsed[0].to_string(), + (info_parsed[1].to_string(), Some(new_name)), + ); + } else { + let info_parsed = binding[start..i].split(':').collect::>(); + name_map.insert( + info_parsed[0].to_string(), + (info_parsed[1].to_string(), None), + ); + } + start = i + 1; + } else if i == binding.len() - 1 { + let info_parsed = binding[start..].split(':').collect::>(); + name_map.insert( + info_parsed[0].to_string(), + (info_parsed[1].to_string(), None), + ); + } + + if char_val == '<' || char_val == '(' { + stack.push(char_val); + } + + if char_val == '>' || char_val == ')' { + stack.pop(); + } + } + + // if !name_map.is_empty() { + // let mut last = tok_split.last().unwrap().clone(); + // for i in stack { + // last.push(i) + // } + // let len = tok_split.len(); + // tok_split[len - 1] = last; + // } + + name_map +} + +fn get_new_name(binding: String, start: usize, i: usize) -> String { + let fragments: Vec<&str> = binding[start..i].split("name=").collect(); + let mut quote_count = 0; + let mut start = 0; + for (j, char_val_inner) in fragments[1].chars().enumerate() { + if char_val_inner == '"' { + quote_count += 1; + + if quote_count == 1 { + start = j + 1; + } + } + + if quote_count == 2 { + return fragments[1][start..j].to_string(); + } + } + + String::new() +} + +#[derive(Debug, Clone)] +pub struct Pyo3Collection(pub Vec); + +impl AddAssign for Pyo3Collection { + fn add_assign(&mut self, rhs: Self) { + self.0.extend(rhs.0); + } +} + +pub fn build_derive_into_pydict(dict_fields: Pyo3Collection) -> TokenStream { + let mut body = quote! { + let mut pydict = pyo3::types::PyDict::new(py); + }; + + for field in &dict_fields.0 { + let ident: &String; + if let Some(ref val) = field.attr_name { + ident = val; + } else { + ident = &field.name; + } + let ident_tok: TokenStream = field.name.parse().unwrap(); + if !ident.is_empty() { + match_tok(field, &mut body, ident, ident_tok); + } + } + body.append_all(quote! { + return pydict; + }); + + body +} + +fn match_tok( + field: &Pyo3DictField, + body: &mut TokenStream, + ident: &String, + ident_tok: TokenStream, +) { + match field.attr_type { + Pyo3Type::Primitive => { + body.append_all(quote! { + pydict.set_item(#ident, self.#ident_tok).expect("Bad element in set_item"); + }); + } + Pyo3Type::NonPrimitive => { + body.append_all(quote! { + pydict.set_item(#ident, self.#ident_tok.into_py_dict(py)).expect("Bad element in set_item"); + }); + } + Pyo3Type::CollectionSing(ref collection) => { + let non_class_ident = ident.replace('.', "_"); + body.append_all(handle_single_collection_code_gen( + collection, + &format!("self.{}", ident_tok), + &non_class_ident, + 0, + )); + + let ls_name: TokenStream = format!("pylist0{}", ident).parse().unwrap(); + body.append_all(quote! { + pydict.set_item(#ident, #ls_name).expect("Bad element in set_item"); + }); + } // Pyo3Type::Map(ref key, ref val) => { + // if let Pyo3Type::NonPrimitive = key.as_ref() { + // panic!("Key must be a primitive type to be derived into a dict. If you want to use non primitive as a dict key, use a custom implementation"); + // } + + // match val.as_ref() { + // Pyo3Type::Primitive => todo!(), + // Pyo3Type::NonPrimitive => todo!(), + // Pyo3Type::CollectionSing(_) => todo!(), + // Pyo3Type::Map(_, _) => todo!(), + // } + // } + }; +} + +fn handle_single_collection_code_gen( + py_type: &Pyo3Type, + ident: &str, + non_class_ident: &str, + counter: usize, +) -> TokenStream { + let curr_pylist: TokenStream = format!("pylist{}{}", counter, non_class_ident) + .parse() + .unwrap(); + let next_pylist: TokenStream = format!("pylist{}{}", counter + 1, non_class_ident) + .parse() + .unwrap(); + let ident_tok: TokenStream = ident.parse().unwrap(); + match py_type { + Pyo3Type::Primitive => { + quote! { + let mut #curr_pylist = pyo3::types::PyList::empty(py); + for i in #ident_tok.into_iter() { + #curr_pylist.append(i).expect("Bad element in set_item"); + }; + } + } + Pyo3Type::NonPrimitive => { + quote! { + let mut #curr_pylist = pyo3::types::PyList::empty(py); + for i in #ident_tok.into_iter() { + #curr_pylist.append(i.into_py_dict(py)).expect("Bad element in set_item"); + }; + } + } + Pyo3Type::CollectionSing(coll) => { + let body = + handle_single_collection_code_gen(coll.as_ref(), "i", non_class_ident, counter + 1); + quote! { + let mut #curr_pylist = pyo3::types::PyList::empty(py); + for i in #ident_tok.into_iter(){ + #body + #curr_pylist.append(#next_pylist).expect("Bad element in set_item"); + }; + } + } + } +} + +pub fn parse_generics(generics: &Generics) -> String { + if !generics.params.is_empty() { + let mut generics_parsed = "<".to_string(); + + for param in &generics.params { + match param { + syn::GenericParam::Lifetime(lt) => { + generics_parsed += ("'".to_string() + <.lifetime.ident.to_string()).as_str() + } + syn::GenericParam::Type(generic_type) => { + generics_parsed += generic_type.ident.to_string().as_str() + } + syn::GenericParam::Const(const_type) => { + generics_parsed += + ("const".to_string() + const_type.ident.to_string().as_str()).as_str() + } + } + + generics_parsed += ","; + } + + generics_parsed = generics_parsed[0..generics_parsed.len() - 1].to_string(); + generics_parsed += ">"; + generics_parsed + } else { + String::new() + } +} + +pub fn check_type(input: &DeriveInput) -> syn::Result<()> { + match input.data { + syn::Data::Struct(ref info) => { + if let syn::Fields::Unnamed(_) = info.fields { + return Err(syn::Error::new( + info.struct_token.span, + "No support for tuple structs currently. Please write your own implementation for the struct.", + )); + } + + Ok(()) + } + syn::Data::Enum(ref info) => Err(syn::Error::new( + info.brace_token.span.close(), + "No support for enums currently. Please write your own implementation for the enum.", + )), + syn::Data::Union(ref info) => Err(syn::Error::new( + info.union_token.span, + "No support for unions currently. Please write your own implementation for the union.", + )), + } +} diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index 745a8471c2b..3eb86db702c 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -11,6 +11,7 @@ mod utils; mod attributes; mod deprecations; mod frompyobject; +mod intopydict; mod konst; mod method; mod module; @@ -22,6 +23,7 @@ mod pymethod; mod quotes; pub use frompyobject::build_derive_from_pyobject; +pub use intopydict::{build_derive_into_pydict, check_type, parse_generics, Pyo3Collection}; pub use module::{process_functions_in_module, pymodule_impl, PyModuleOptions}; pub use pyclass::{build_py_class, build_py_enum, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionOptions}; diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 37c7e6e9b99..3ef2568420c 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -7,12 +7,13 @@ extern crate proc_macro; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use pyo3_macros_backend::{ - build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods, - get_doc, process_functions_in_module, pymodule_impl, PyClassArgs, PyClassMethodsType, - PyFunctionOptions, PyModuleOptions, + build_derive_from_pyobject, build_derive_into_pydict, build_py_class, build_py_enum, + build_py_function, build_py_methods, check_type, get_doc, parse_generics, + process_functions_in_module, pymodule_impl, PyClassArgs, PyClassMethodsType, PyFunctionOptions, + PyModuleOptions, Pyo3Collection, }; -use quote::quote; -use syn::{parse::Nothing, parse_macro_input}; +use quote::{quote, ToTokens}; +use syn::{parse::Nothing, parse_macro_input, DeriveInput}; /// A proc macro used to implement Python modules. /// @@ -160,6 +161,38 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream { .into() } +#[proc_macro_derive(IntoPyDict, attributes(pyo3, into_py_dict_ignore))] +pub fn derive_into_pydict(item: TokenStream) -> TokenStream { + let cloned = item.clone(); + let ast = parse_macro_input!(cloned as DeriveInput); + check_type(&ast).unwrap(); + let ident = ast.ident.into_token_stream(); + let clause_wrapped = ast.generics.where_clause.clone(); + let mut where_clause: TokenStream2 = TokenStream2::new(); + let generic_params: TokenStream2 = parse_generics(&ast.generics).parse().unwrap(); + let generics = ast.generics.into_token_stream(); + + if let Some(clause) = clause_wrapped { + where_clause = clause.into_token_stream(); + } + let mut dict_fields: Pyo3Collection = Pyo3Collection(Vec::new()); + for token in item { + let token_stream: syn::__private::TokenStream = token.into(); + dict_fields += parse_macro_input!(token_stream as Pyo3Collection); + } + let body: TokenStream2 = build_derive_into_pydict(dict_fields); + let out = quote! { + + impl #generics pyo3::types::IntoPyDict for #ident #generic_params #where_clause { + fn into_py_dict(self, py: pyo3::Python<'_>) -> &pyo3::types::PyDict { + #body + } + } + }; + + out.into() +} + fn pyclass_impl( attrs: TokenStream, mut ast: syn::ItemStruct, diff --git a/src/prelude.rs b/src/prelude.rs index ca0b0cf38db..2de3ae7d746 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -19,7 +19,7 @@ pub use crate::pyclass_init::PyClassInitializer; pub use crate::types::{PyAny, PyModule}; #[cfg(feature = "macros")] -pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, FromPyObject}; +pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, FromPyObject, IntoPyDict}; #[cfg(feature = "macros")] pub use crate::wrap_pyfunction; diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 8a22d66cdbf..8e9cec71a52 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -37,4 +37,5 @@ fn test_compile_errors() { t.compile_fail("tests/ui/not_send2.rs"); t.compile_fail("tests/ui/get_set_all.rs"); t.compile_fail("tests/ui/traverse.rs"); + t.compile_fail("tests/ui/invalid_intopydict.rs"); } diff --git a/tests/test_intopydict.rs b/tests/test_intopydict.rs new file mode 100644 index 00000000000..63dcf6269b5 --- /dev/null +++ b/tests/test_intopydict.rs @@ -0,0 +1,74 @@ +#![cfg(feature = "macros")] + +use pyo3::{prelude::IntoPyDict, pyclass, types::IntoPyDict}; + +pub trait TestTrait<'a> {} + +#[pyclass] +#[derive(IntoPyDict)] +pub struct TestDict { + x: u8, +} + +#[derive(IntoPyDict)] +pub struct PyClass { + x: u8, + y: TestDict, +} + +#[derive(IntoPyDict, PartialEq, Debug, Clone)] +pub struct Test1 { + x: u8, +} + +#[derive(IntoPyDict, Clone, Debug)] +pub struct Test { + #[pyo3(get, set, name = "hello")] + v: Vec>, + j: Test1, + #[pyo3(get, set, name = "world")] + h: u8, +} + +#[derive(IntoPyDict)] +pub struct TestGeneric { + x: T, + y: TestGenericDouble, +} + +#[derive(IntoPyDict)] +pub struct TestGenericDouble { + x: T, + y: U, +} + +#[derive(IntoPyDict)] +pub struct TestVecPrim { + v: Vec, +} + +#[test] +fn test_into_py_dict_derive() { + let test_struct = Test { + v: vec![vec![Test1 { x: 9 }]], + j: Test1 { x: 10 }, + h: 9, + }; + + let test_generic_struct = TestGeneric { + x: test_struct.clone(), + y: TestGenericDouble { + x: test_struct.clone(), + y: test_struct.clone(), + }, + }; + + pyo3::Python::with_gil(|py| { + let py_dict = test_struct.into_py_dict(py); + let h: u8 = py_dict.get_item("world").unwrap().extract().unwrap(); + + assert_eq!(h, 9); + let pydict = test_generic_struct.into_py_dict(py); + println!("{:?}", pydict); + }); +} diff --git a/tests/ui/invalid_intopydict.rs b/tests/ui/invalid_intopydict.rs new file mode 100644 index 00000000000..3f777a7cfad --- /dev/null +++ b/tests/ui/invalid_intopydict.rs @@ -0,0 +1,11 @@ +use pyo3::prelude::IntoPyDict; + +#[derive(IntoPyDict)] +pub struct TestPyTupleInvalid(u8); + +#[derive(IntoPyDict)] +pub enum TestEnumInvalid { + Variant1 +} + +fn main() {} \ No newline at end of file diff --git a/tests/ui/invalid_intopydict.stderr b/tests/ui/invalid_intopydict.stderr new file mode 100644 index 00000000000..39e5f16220b --- /dev/null +++ b/tests/ui/invalid_intopydict.stderr @@ -0,0 +1,15 @@ +error: proc-macro derive panicked + --> tests/ui/invalid_intopydict.rs:3:10 + | +3 | #[derive(IntoPyDict)] + | ^^^^^^^^^^ + | + = help: message: called `Result::unwrap()` on an `Err` value: Error("No support for tuple structs currently. Please write your own implementation for the struct.") + +error: proc-macro derive panicked + --> tests/ui/invalid_intopydict.rs:6:10 + | +6 | #[derive(IntoPyDict)] + | ^^^^^^^^^^ + | + = help: message: called `Result::unwrap()` on an `Err` value: Error("No support for enums currently. Please write your own implementation for the enum.")