diff --git a/pyo3-macros-backend/src/dict.rs b/pyo3-macros-backend/src/dict.rs new file mode 100644 index 00000000000..2d248b0aac3 --- /dev/null +++ b/pyo3-macros-backend/src/dict.rs @@ -0,0 +1,73 @@ +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::{quote, ToTokens}; +use std::iter::FromIterator; +use syn::parse::{Parse, ParseBuffer, ParseStream}; +use syn::punctuated::Punctuated; +use syn::Token; +use syn::{braced, Expr}; + +#[derive(Debug)] +pub struct PyDictLiteral { + pub py: Ident, + pub items: Vec, +} + +#[derive(Debug)] +pub struct KeyValue { + key: syn::Expr, + value: syn::Expr, +} + +#[derive(Debug)] +struct Key(syn::Expr); + +impl Parse for Key { + fn parse(input: ParseStream) -> syn::Result { + let mut tokens = vec![]; + + while !input.peek(Token![:]) || input.peek(Token![::]) { + let tt = input.parse::()?; + tokens.push(tt); + } + let stream = TokenStream::from_iter(tokens.into_iter()); + + let expr = syn::parse2::(stream)?; + Ok(Self(expr)) + } +} + +impl Parse for KeyValue { + fn parse(input: ParseStream) -> syn::Result { + let key: Key = input.parse()?; + let _sep: Token![:] = input.parse()?; + let value: syn::Expr = input.parse()?; + + Ok(Self { key: key.0, value }) + } +} + +impl Parse for PyDictLiteral { + fn parse(input: ParseStream) -> syn::Result { + let py: Ident = input.parse()?; + let _arrow: Token![=>] = input.parse()?; + + let body: ParseBuffer; + braced!(body in input); + + let items: Punctuated = Punctuated::parse_terminated(&body)?; + + Ok(Self { + py, + items: items.into_iter().collect(), + }) + } +} + +impl ToTokens for KeyValue { + fn to_tokens(&self, tokens: &mut TokenStream) { + let key = &self.key; + let value = &self.value; + let ts = quote! {(#key, #value)}; + tokens.extend(ts); + } +} diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index 69fc24d250a..80025046a79 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -11,6 +11,7 @@ mod utils; mod attributes; mod defs; mod deprecations; +mod dict; mod from_pyobject; mod konst; mod method; @@ -23,6 +24,7 @@ mod pyimpl; mod pymethod; mod pyproto; +pub use dict::PyDictLiteral; pub use from_pyobject::build_derive_from_pyobject; pub use module::{process_functions_in_module, py_init, PyModuleOptions}; pub use pyclass::{build_py_class, build_py_enum, PyClassArgs}; diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index a6f6725c6e5..30089a400a8 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -9,7 +9,7 @@ use proc_macro::TokenStream; use pyo3_macros_backend::{ build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods, build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyClassMethodsType, - PyFunctionOptions, PyModuleOptions, + PyDictLiteral, PyFunctionOptions, PyModuleOptions, }; use quote::quote; use syn::{parse::Nothing, parse_macro_input}; @@ -199,6 +199,21 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream { .into() } +#[proc_macro] +pub fn py_dict(input: proc_macro::TokenStream) -> TokenStream { + let PyDictLiteral { items, py } = parse_macro_input!(input as PyDictLiteral); + let stream = quote! { + (|| { + use pyo3 as _pyo3; + + let dict = _pyo3::types::PyDict::new(#py); + #(dict.set_item#items?;)* + _pyo3::PyResult::Ok(dict) + })() + }; + stream.into() +} + fn pyclass_impl( attrs: TokenStream, mut ast: syn::ItemStruct, diff --git a/src/lib.rs b/src/lib.rs index e8ebffcb8c7..691756f29fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -355,7 +355,7 @@ pub mod proc_macro { } #[cfg(feature = "macros")] -pub use pyo3_macros::{pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject}; +pub use pyo3_macros::{py_dict, pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject}; #[macro_use] mod macros; diff --git a/src/types/macros.rs b/src/types/macros.rs new file mode 100644 index 00000000000..6d34bb7c511 --- /dev/null +++ b/src/types/macros.rs @@ -0,0 +1,144 @@ +#[cfg(feature = "macros")] +pub use pyo3_macros::py_dict; + +#[doc(hidden)] +#[macro_export] +macro_rules! py_object_vec { + ($py:ident, [$($item:expr),+]) => {{ + let items_vec: Vec<$crate::PyObject> = + vec![$($crate::conversion::IntoPy::into_py($item, $py)),+]; + items_vec + }}; +} + +#[macro_export] +macro_rules! py_list { + ($py:ident, [$($items:expr),+]) => {{ + let items_vec = $crate::py_object_vec!($py, [$($items),+]); + $crate::types::list::PyList::new($py, items_vec) + }}; +} + +#[macro_export] +macro_rules! py_tuple { + ($py:ident, ($($items:expr),+)) => {{ + let items_vec = $crate::py_object_vec!($py, [$($items),+]); + $crate::types::PyTuple::new($py, items_vec) + }}; +} + +#[macro_export] +macro_rules! py_set { + ($py:ident, {$($items:expr),+}) => {{ + let items_vec = $crate::py_object_vec!($py, [$($items),+]); + $crate::types::set::PySet::new($py, items_vec.as_slice()) + }}; +} + +#[macro_export] +macro_rules! py_frozenset { + ($py:ident, {$($items:expr),+}) => {{ + let items_vec = $crate::py_object_vec!($py, [$($items),+]); + $crate::types::set::PyFrozenSet::new($py, items_vec.as_slice()) + }}; +} + +#[cfg(test)] +mod test { + use crate::types::PyFrozenSet; + use crate::Python; + + #[test] + fn test_list_macro() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let single_item_list = py_list!(py, ["elem"]); + assert_eq!( + "elem", + single_item_list + .get_item(0) + .expect("failed to get item") + .extract::<&str>() + .unwrap() + ); + + let multi_item_list = py_list!(py, ["elem1", "elem2", 3, 4]); + + assert_eq!( + "['elem1', 'elem2', 3, 4]", + multi_item_list.str().unwrap().extract::<&str>().unwrap() + ); + } + + #[test] + fn test_tuple_macro() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let single_item_tuple = py_tuple!(py, ("elem")); + assert_eq!( + "elem", + single_item_tuple + .get_item(0) + .expect("failed to get item") + .extract::<&str>() + .unwrap() + ); + + let multi_item_tuple = py_tuple!(py, ("elem1", "elem2", 3, 4)); + + assert_eq!( + "('elem1', 'elem2', 3, 4)", + multi_item_tuple.str().unwrap().extract::<&str>().unwrap() + ); + } + + #[test] + fn test_set_macro() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let set = py_set!(py, { "set_elem" }).expect("failed to create set"); + + assert!(set.contains("set_elem").unwrap()); + + set.call_method1( + "update", + py_tuple!( + py, + (py_set!(py, {"new_elem1", "new_elem2", "set_elem"}).unwrap()) + ), + ) + .expect("failed to update set"); + + for &expected_elem in &["set_elem", "new_elem1", "new_elem2"] { + assert!(set.contains(expected_elem).unwrap()); + } + } + + #[test] + fn test_frozenset_macro() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let frozenset = py_frozenset!(py, { "set_elem" }).expect("failed to create frozenset"); + + assert!(frozenset.contains("set_elem").unwrap()); + + let intersection = frozenset + .call_method1( + "intersection", + py_tuple!( + py, + (py_set!(py, {"new_elem1", "new_elem2", "set_elem"}).unwrap()) + ), + ) + .expect("failed to call intersection()") + .downcast::() + .expect("failed to downcast to FrozenSet"); + + assert_eq!(1, intersection.len()); + assert!(intersection.contains("set_elem").unwrap()); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index cbbeafb51fe..071601feec8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -233,6 +233,7 @@ mod floatob; mod function; mod iterator; mod list; +mod macros; mod mapping; mod module; mod num; diff --git a/tests/test_literals.rs b/tests/test_literals.rs new file mode 100644 index 00000000000..4c16dfee18e --- /dev/null +++ b/tests/test_literals.rs @@ -0,0 +1,52 @@ +#![cfg(feature = "macros")] + +use pyo3::prelude::*; +use pyo3::{py_dict, py_run, py_tuple}; + +#[test] +fn test_dict_literal() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let dict = py_dict!(py => {"key": "value"}).expect("failed to create dict"); + assert_eq!( + "value", + dict.get_item("key").unwrap().extract::().unwrap() + ); + + let value = "value"; + let multi_elem_dict = + py_dict!(py => {"key1": value, 143: "abcde"}).expect("failed to create dict"); + assert_eq!( + "value", + multi_elem_dict + .get_item("key1") + .unwrap() + .extract::<&str>() + .unwrap() + ); + assert_eq!( + "abcde", + multi_elem_dict + .get_item(143) + .unwrap() + .extract::<&str>() + .unwrap() + ); + + let keys = &["key1", "key2"]; + + let expr_dict = py_dict!(py => { + keys[0]: "value1", + keys[1]: "value2", + 3-7: py_tuple!(py, ("elem1", "elem2", 3)), + "KeY".to_lowercase(): 100 * 2, + }) + .expect("failed to create dict"); + + py_run!( + py, + expr_dict, + "assert expr_dict == {'key1': 'value1', 'key2': 'value2', -4: ('elem1', 'elem2', 3), 'key': 200}" + ); +}