-
Notifications
You must be signed in to change notification settings - Fork 783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Derive Macro for IntoPyDict #3350
Changes from all commits
f64aa86
c35b598
077e470
c87345d
533992d
b77e9e1
2362538
b86bf14
66a3442
4056990
c4bf9db
774864e
b9215d4
3d8e257
fd1d6f4
40c7636
6e5601b
e0264c3
509942e
25f28ca
70468f3
a047c3d
afb6f12
67891b6
31be92d
e6f18ba
b26e64d
7e08475
33f09a3
695a6ad
04dead8
4f99a11
ed271fb
7d3c2f5
6ff9870
0ef5f7c
0641561
ef0d86b
1c88606
3065ba6
14f8b01
8a933b7
be9a6a6
81bd87d
31b4e7f
08366df
d014295
dd594b8
3f35396
896f483
8385b1b
f89ba6a
1415231
ee19677
2adfbee
20c90c4
d8bc842
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Added derive macro for ```IntoPyDict``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,366 @@ | ||
use quote::{quote, TokenStreamExt}; | ||
use std::{collections::HashMap, ops::AddAssign}; | ||
|
||
use proc_macro2::{Span, TokenStream}; | ||
use syn::{ | ||
parse::{Parse, ParseStream}, | ||
DeriveInput, Error, Generics, | ||
}; | ||
|
||
const COL_NAMES: [&str; 8] = [ | ||
"BTreeSet", | ||
"BinaryHeap", | ||
"Vec", | ||
"HashSet", | ||
"LinkedList", | ||
"VecDeque", | ||
"BTreeMap", | ||
"HashMap", | ||
]; | ||
|
||
#[derive(Debug, Clone)] | ||
enum Pyo3Type { | ||
Primitive, | ||
NonPrimitive, | ||
CollectionSing(Box<crate::intopydict::Pyo3Type>), | ||
// Map( | ||
// Box<crate::intopydict::Pyo3Type>, | ||
// Box<crate::intopydict::Pyo3Type>, | ||
// ), | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct Pyo3DictField { | ||
name: String, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably be an |
||
attr_type: Pyo3Type, | ||
attr_name: Option<String>, | ||
} | ||
|
||
impl Pyo3DictField { | ||
pub fn new(name: String, type_: &str, span: Span, attr_name: Option<String>) -> Self { | ||
Self { | ||
name, | ||
attr_type: Self::check_primitive(type_, span), | ||
attr_name, | ||
} | ||
} | ||
|
||
fn check_primitive(attr_type: &str, span: Span) -> Pyo3Type { | ||
for collection in COL_NAMES { | ||
if attr_type.starts_with(collection) { | ||
let attr_type = attr_type.replace('>', ""); | ||
let attr_list: Vec<&str> = attr_type.split('<').collect(); | ||
let out = Self::handle_collection(&attr_list, span); | ||
|
||
return out.unwrap(); | ||
} | ||
} | ||
|
||
match attr_type { | ||
"i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | ||
| "u128" | "usize" | "f32" | "f64" | "char" | "bool" | "&str" | "String" => { | ||
Pyo3Type::Primitive | ||
} | ||
_ => Pyo3Type::NonPrimitive, | ||
} | ||
} | ||
|
||
fn handle_collection(attr_type: &[&str], span: Span) -> syn::Result<Pyo3Type> { | ||
match attr_type[0] { | ||
"BTreeSet" | "BinaryHeap" | "Vec" | "HashSet" | "LinkedList" | "VecDeque" => { | ||
Ok(Pyo3Type::CollectionSing(Box::new( | ||
Self::handle_collection(&attr_type[1..], span).unwrap(), | ||
))) | ||
} | ||
"BTreeMap" | "HashMap" => { | ||
Err(Error::new(span, "Derive currently doesn't support map types. Please use a custom implementation for structs using a map type like HashMap or BTreeMap")) | ||
} | ||
"i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | ||
| "u128" | "usize" | "f32" | "f64" | "char" | "bool" | "&str" | "String" => { | ||
Ok(Pyo3Type::Primitive) | ||
} | ||
_ => Ok(Pyo3Type::NonPrimitive), | ||
} | ||
} | ||
} | ||
|
||
impl Parse for Pyo3Collection { | ||
fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | ||
let tok_stream: TokenStream = input.parse()?; | ||
let mut binding = tok_stream | ||
.to_string() | ||
.as_str() | ||
.replace(|c| c == ' ' || c == '{' || c == '}' || c == '\n', ""); | ||
|
||
if !binding.contains(':') { | ||
return Ok(Pyo3Collection(Vec::new())); | ||
} | ||
|
||
if binding.as_bytes()[binding.len() - 1] as char != ',' { | ||
binding.push(','); | ||
} | ||
|
||
let name_map = split_struct(binding); | ||
|
||
let mut field_collection: Vec<Pyo3DictField> = Vec::new(); | ||
|
||
for (field_name, (field_val, dict_key)) in &name_map { | ||
field_collection.push(Pyo3DictField::new( | ||
field_name.to_string(), | ||
field_val, | ||
input.span(), | ||
dict_key.clone(), | ||
)) | ||
} | ||
|
||
Ok(Pyo3Collection(field_collection)) | ||
} | ||
} | ||
|
||
fn split_struct(binding: String) -> HashMap<String, (String, Option<String>)> { | ||
let mut stack: Vec<char> = Vec::new(); | ||
let mut start = 0; | ||
let binding = binding.replace('\n', ""); | ||
let mut name_map: HashMap<String, (String, Option<String>)> = HashMap::new(); | ||
|
||
for (i, char_val) in binding.chars().enumerate() { | ||
if char_val == ',' && stack.is_empty() { | ||
if binding[start..i].starts_with('#') { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you might want to take a look at |
||
let new_name = get_new_name(binding.clone(), start, i); | ||
let var_string = &binding[start..i].split(']').collect::<Vec<&str>>()[1]; | ||
let info_parsed = var_string.split(':').collect::<Vec<&str>>(); | ||
name_map.insert( | ||
info_parsed[0].to_string(), | ||
(info_parsed[1].to_string(), Some(new_name)), | ||
); | ||
} else { | ||
let info_parsed = binding[start..i].split(':').collect::<Vec<&str>>(); | ||
name_map.insert( | ||
info_parsed[0].to_string(), | ||
(info_parsed[1].to_string(), None), | ||
); | ||
} | ||
start = i + 1; | ||
} else if i == binding.len() - 1 { | ||
let info_parsed = binding[start..].split(':').collect::<Vec<&str>>(); | ||
name_map.insert( | ||
info_parsed[0].to_string(), | ||
(info_parsed[1].to_string(), None), | ||
); | ||
} | ||
|
||
if char_val == '<' || char_val == '(' { | ||
stack.push(char_val); | ||
} | ||
|
||
if char_val == '>' || char_val == ')' { | ||
stack.pop(); | ||
} | ||
} | ||
|
||
// if !name_map.is_empty() { | ||
// let mut last = tok_split.last().unwrap().clone(); | ||
// for i in stack { | ||
// last.push(i) | ||
// } | ||
// let len = tok_split.len(); | ||
// tok_split[len - 1] = last; | ||
// } | ||
|
||
name_map | ||
} | ||
|
||
fn get_new_name(binding: String, start: usize, i: usize) -> String { | ||
let fragments: Vec<&str> = binding[start..i].split("name=").collect(); | ||
let mut quote_count = 0; | ||
let mut start = 0; | ||
for (j, char_val_inner) in fragments[1].chars().enumerate() { | ||
if char_val_inner == '"' { | ||
quote_count += 1; | ||
|
||
if quote_count == 1 { | ||
start = j + 1; | ||
} | ||
} | ||
|
||
if quote_count == 2 { | ||
return fragments[1][start..j].to_string(); | ||
} | ||
} | ||
|
||
String::new() | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct Pyo3Collection(pub Vec<Pyo3DictField>); | ||
|
||
impl AddAssign for Pyo3Collection { | ||
fn add_assign(&mut self, rhs: Self) { | ||
self.0.extend(rhs.0); | ||
} | ||
} | ||
|
||
pub fn build_derive_into_pydict(dict_fields: Pyo3Collection) -> TokenStream { | ||
let mut body = quote! { | ||
let mut pydict = pyo3::types::PyDict::new(py); | ||
}; | ||
|
||
for field in &dict_fields.0 { | ||
let ident: &String; | ||
if let Some(ref val) = field.attr_name { | ||
ident = val; | ||
} else { | ||
ident = &field.name; | ||
} | ||
let ident_tok: TokenStream = field.name.parse().unwrap(); | ||
if !ident.is_empty() { | ||
match_tok(field, &mut body, ident, ident_tok); | ||
} | ||
} | ||
body.append_all(quote! { | ||
return pydict; | ||
}); | ||
|
||
body | ||
} | ||
|
||
fn match_tok( | ||
field: &Pyo3DictField, | ||
body: &mut TokenStream, | ||
ident: &String, | ||
ident_tok: TokenStream, | ||
) { | ||
match field.attr_type { | ||
Pyo3Type::Primitive => { | ||
body.append_all(quote! { | ||
pydict.set_item(#ident, self.#ident_tok).expect("Bad element in set_item"); | ||
}); | ||
} | ||
Pyo3Type::NonPrimitive => { | ||
body.append_all(quote! { | ||
pydict.set_item(#ident, self.#ident_tok.into_py_dict(py)).expect("Bad element in set_item"); | ||
}); | ||
} | ||
Pyo3Type::CollectionSing(ref collection) => { | ||
let non_class_ident = ident.replace('.', "_"); | ||
body.append_all(handle_single_collection_code_gen( | ||
collection, | ||
&format!("self.{}", ident_tok), | ||
&non_class_ident, | ||
0, | ||
)); | ||
|
||
let ls_name: TokenStream = format!("pylist0{}", ident).parse().unwrap(); | ||
body.append_all(quote! { | ||
pydict.set_item(#ident, #ls_name).expect("Bad element in set_item"); | ||
}); | ||
} // Pyo3Type::Map(ref key, ref val) => { | ||
// if let Pyo3Type::NonPrimitive = key.as_ref() { | ||
// panic!("Key must be a primitive type to be derived into a dict. If you want to use non primitive as a dict key, use a custom implementation"); | ||
// } | ||
|
||
// match val.as_ref() { | ||
// Pyo3Type::Primitive => todo!(), | ||
// Pyo3Type::NonPrimitive => todo!(), | ||
// Pyo3Type::CollectionSing(_) => todo!(), | ||
// Pyo3Type::Map(_, _) => todo!(), | ||
// } | ||
// } | ||
}; | ||
} | ||
|
||
fn handle_single_collection_code_gen( | ||
py_type: &Pyo3Type, | ||
ident: &str, | ||
non_class_ident: &str, | ||
counter: usize, | ||
) -> TokenStream { | ||
let curr_pylist: TokenStream = format!("pylist{}{}", counter, non_class_ident) | ||
.parse() | ||
.unwrap(); | ||
let next_pylist: TokenStream = format!("pylist{}{}", counter + 1, non_class_ident) | ||
.parse() | ||
.unwrap(); | ||
let ident_tok: TokenStream = ident.parse().unwrap(); | ||
match py_type { | ||
Pyo3Type::Primitive => { | ||
quote! { | ||
let mut #curr_pylist = pyo3::types::PyList::empty(py); | ||
for i in #ident_tok.into_iter() { | ||
#curr_pylist.append(i).expect("Bad element in set_item"); | ||
}; | ||
} | ||
} | ||
Pyo3Type::NonPrimitive => { | ||
quote! { | ||
let mut #curr_pylist = pyo3::types::PyList::empty(py); | ||
for i in #ident_tok.into_iter() { | ||
#curr_pylist.append(i.into_py_dict(py)).expect("Bad element in set_item"); | ||
}; | ||
} | ||
} | ||
Pyo3Type::CollectionSing(coll) => { | ||
let body = | ||
handle_single_collection_code_gen(coll.as_ref(), "i", non_class_ident, counter + 1); | ||
quote! { | ||
let mut #curr_pylist = pyo3::types::PyList::empty(py); | ||
for i in #ident_tok.into_iter(){ | ||
#body | ||
#curr_pylist.append(#next_pylist).expect("Bad element in set_item"); | ||
}; | ||
} | ||
} | ||
} | ||
} | ||
|
||
pub fn parse_generics(generics: &Generics) -> String { | ||
if !generics.params.is_empty() { | ||
let mut generics_parsed = "<".to_string(); | ||
|
||
for param in &generics.params { | ||
match param { | ||
syn::GenericParam::Lifetime(lt) => { | ||
generics_parsed += ("'".to_string() + <.lifetime.ident.to_string()).as_str() | ||
} | ||
syn::GenericParam::Type(generic_type) => { | ||
generics_parsed += generic_type.ident.to_string().as_str() | ||
} | ||
syn::GenericParam::Const(const_type) => { | ||
generics_parsed += | ||
("const".to_string() + const_type.ident.to_string().as_str()).as_str() | ||
} | ||
} | ||
|
||
generics_parsed += ","; | ||
} | ||
|
||
generics_parsed = generics_parsed[0..generics_parsed.len() - 1].to_string(); | ||
generics_parsed += ">"; | ||
generics_parsed | ||
} else { | ||
String::new() | ||
} | ||
} | ||
Comment on lines
+316
to
+343
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you! Ill keep that in mind |
||
|
||
pub fn check_type(input: &DeriveInput) -> syn::Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function should not exist. Instead match on the pub fn handle_deriveinput(input: &DeriveInput) -> syn::Result<???> {
match input.data {
syn::Data::Struct(strukt @ syn::Datastruct { fields: syn::Fields::Named(_), ..}) => {
handle_struct(strukt)
}
_ => Err(syn::Error::new(
// span,
"IntoPyDict derive is only possible for structs with named fields",
))
}
}
pub fn handle_struct(input: &syn::DataStruct) -> syn::Result<???>{
// ...
} Instead of having a bunch of "check" functions, parse, don't validate. Did you spot the bug/oversight in your code that this code does not have? What about structs without fields? 🙃 |
||
match input.data { | ||
syn::Data::Struct(ref info) => { | ||
if let syn::Fields::Unnamed(_) = info.fields { | ||
return Err(syn::Error::new( | ||
info.struct_token.span, | ||
"No support for tuple structs currently. Please write your own implementation for the struct.", | ||
)); | ||
} | ||
|
||
Ok(()) | ||
} | ||
syn::Data::Enum(ref info) => Err(syn::Error::new( | ||
info.brace_token.span.close(), | ||
"No support for enums currently. Please write your own implementation for the enum.", | ||
)), | ||
syn::Data::Union(ref info) => Err(syn::Error::new( | ||
info.union_token.span, | ||
"No support for unions currently. Please write your own implementation for the union.", | ||
)), | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel quite strongly that instead of special-casing this macro on names of types the implementation should emit code based on traits. In this case, if we go with the proposal of deep
IntoPyDict
, we should use autoref-specialization which prefersIntoPyDict
if implemented and otherwise falls back onIntoPy<PyObject>
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this apply to collections as well (ie Vecs, Sets etc)? These would get stored as a PyObject instead of a PyList?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And how do you recommend I use autoref specialization? I don't have much experience with it. Is there a crate which makes it easier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A separate crate usually doesn't help because abstracting the pattern away is often more complex than just applying it.
In this case, the general technique is described at https://github.com/dtolnay/case-studies/blob/master/autoref-specialization/README.md and you can find an example of its application in PyO3's existing macros by looking for usages of the type
PyClassImplCollector
defined atpyo3/src/impl_/pyclass.rs
Line 109 in ae982b8
and for example its relation to the trait
PyClassNewTextSignature
defined atpyo3/src/impl_/pyclass.rs
Line 995 in ae982b8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I add the specialization feature as a crate attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No; the specialization feature is incomplete and unsound. Autoref specialization is a technique which works on stable Rust to get a specialization-like result for some limited cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There seems to be some issues with reference moving when I try to do autoref specialization. Sorry it took me a while for this, I was finishing up an internship