diff --git a/README.md b/README.md index 47cb58f..1684de5 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,11 @@ println!("{}", timestamp); // Timestamp(1701620628123456789) - Conditionally, based on the underlying data type, traits like `Copy`, `Eq`, `Ord`, `Hash` may also be implemented. For primitive data types like `i32` or `bool`, these additional traits will be automatically included. - Numeric types, both integer and floating-point, also implement constants `MIN`, `MAX`, and `ZERO`. Additionally, for floating-point types, `NAN` is implemented. -- **Attributes:** Adding the following attributes to `#[strong_type(...)]` allows for additional features: - - `auto_operators`: Automatically implements relevant arithmetic (for numeric types) or logical (for boolean types) operators. - - `custom_display`: Allows users to manually implement the `Display` trait, providing an alternative to the default display format. +- **Attributes:** + - Adding the following attributes to `#[strong_type(...)]` allows for additional features: + - `auto_operators`: Automatically implements relevant arithmetic (for numeric types) or logical (for boolean types) operators. + - `custom_display`: Allows users to manually implement the `Display` trait, providing an alternative to the default display format. + - Specifying the corresponding primitive types via `#[custom_underlying(...)]` for nested strong types. ## Installation Add `strong-type` to your `Cargo.toml`: @@ -36,6 +38,7 @@ strong-type = "0.7" - Boolean type: `bool` - `char` - `String` + - Strong types of the above types ## Examples #### Creating a named strong type: @@ -142,3 +145,21 @@ impl Display for Second { println!("{}", Second::new(std::f64::consts::E)); // "Second(2.72)" println!("{:?}", Second::new(std::f64::consts::E)); // "Second { value: 2.718281828459045 }" ``` + +#### Nested strong types: + +```rust +#[derive(StrongType)] +#[strong_type(auto_operators)] +struct Dollar(i32); + +#[derive(StrongType)] +#[strong_type(auto_operators)] +#[custom_underlying(i32)] +struct Cash(Dollar); + +#[derive(StrongType)] +#[strong_type(auto_operators)] +#[custom_underlying(i32)] +struct Coin(Cash); +``` \ No newline at end of file diff --git a/strong-type-derive/src/detail/basic_primitive.rs b/strong-type-derive/src/detail/basic_primitive.rs index 59a13d5..01832db 100644 --- a/strong-type-derive/src/detail/basic_primitive.rs +++ b/strong-type-derive/src/detail/basic_primitive.rs @@ -25,3 +25,29 @@ pub(crate) fn implement_basic_primitive(name: &syn::Ident, value_type: &syn::Ide } } } + +pub(crate) fn implement_primitive_accessor( + name: &syn::Ident, + primitive_type: &syn::Ident, +) -> TokenStream { + quote! { + impl #name { + pub fn primitive(&self) -> #primitive_type { + self.value() + } + } + } +} + +pub(crate) fn implement_primitive_accessor_derived( + name: &syn::Ident, + primitive_type: &syn::Ident, +) -> TokenStream { + quote! { + impl #name { + pub fn primitive(&self) -> #primitive_type { + self.0.primitive() + } + } + } +} diff --git a/strong-type-derive/src/detail/basic_string.rs b/strong-type-derive/src/detail/basic_string.rs index 0d7addd..7478cb0 100644 --- a/strong-type-derive/src/detail/basic_string.rs +++ b/strong-type-derive/src/detail/basic_string.rs @@ -3,12 +3,6 @@ use quote::quote; pub(crate) fn implement_basic_string(name: &syn::Ident) -> TokenStream { quote! { - impl #name { - pub fn value(&self) -> &str { - self.0.as_str() - } - } - impl Clone for #name { fn clone(&self) -> Self { Self(self.0.clone()) @@ -22,3 +16,34 @@ pub(crate) fn implement_basic_string(name: &syn::Ident) -> TokenStream { } } } + +pub(crate) fn implement_primitive_str_accessor(name: &syn::Ident) -> TokenStream { + quote! { + impl #name { + pub fn value(&self) -> &str { + self.0.as_str() + } + + pub fn primitive(&self) -> &str { + self.value() + } + } + } +} + +pub(crate) fn implement_primitive_str_accessor_derived( + name: &syn::Ident, + value_type: &syn::Ident, +) -> TokenStream { + quote! { + impl #name { + pub fn value(&self) -> &#value_type { + &self.0 + } + + pub fn primitive(&self) -> &str { + self.0.primitive() + } + } + } +} diff --git a/strong-type-derive/src/detail/constants.rs b/strong-type-derive/src/detail/constants.rs index 04cb505..6819d60 100644 --- a/strong-type-derive/src/detail/constants.rs +++ b/strong-type-derive/src/detail/constants.rs @@ -11,3 +11,16 @@ pub(crate) fn implement_constants(name: &syn::Ident, value_type: &syn::Ident) -> } } } +pub(crate) fn implement_constants_derived( + name: &syn::Ident, + value_type: &syn::Ident, +) -> TokenStream { + quote! { + impl #name { + pub const MIN: Self = Self(#value_type::MIN); + pub const MAX: Self = Self(#value_type::MAX); + pub const ZERO: Self = Self(#value_type::ZERO); + pub const ONE: Self = Self(#value_type::ONE); + } + } +} diff --git a/strong-type-derive/src/detail/mod.rs b/strong-type-derive/src/detail/mod.rs index e112118..e564d41 100644 --- a/strong-type-derive/src/detail/mod.rs +++ b/strong-type-derive/src/detail/mod.rs @@ -9,19 +9,24 @@ mod display; mod hash; mod nan; mod negate; -mod underlying_type; +mod underlying_type_utils; mod utils; pub(crate) use arithmetic::implement_arithmetic; pub(crate) use basic::implement_basic; -pub(crate) use basic_primitive::implement_basic_primitive; -pub(crate) use basic_string::implement_basic_string; +pub(crate) use basic_primitive::{ + implement_basic_primitive, implement_primitive_accessor, implement_primitive_accessor_derived, +}; +pub(crate) use basic_string::{ + implement_basic_string, implement_primitive_str_accessor, + implement_primitive_str_accessor_derived, +}; pub(crate) use bit_ops::implement_bit_shift; pub(crate) use bool_ops::implement_bool_ops; -pub(crate) use constants::implement_constants; +pub(crate) use constants::{implement_constants, implement_constants_derived}; pub(crate) use display::implement_display; pub(crate) use hash::implement_hash; pub(crate) use nan::implement_nan; pub(crate) use negate::implement_negate; -pub(crate) use underlying_type::{get_type_group, get_type_ident, UnderlyingTypeGroup}; +pub(crate) use underlying_type_utils::{get_type, TypeInfo, UnderlyingType, ValueTypeGroup}; pub(crate) use utils::{get_attributes, is_struct_valid, StrongTypeAttributes}; diff --git a/strong-type-derive/src/detail/underlying_type.rs b/strong-type-derive/src/detail/underlying_type.rs deleted file mode 100644 index e43dd9b..0000000 --- a/strong-type-derive/src/detail/underlying_type.rs +++ /dev/null @@ -1,55 +0,0 @@ -use syn::{Data, DeriveInput, Type}; - -#[repr(u8)] -#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)] -pub(crate) enum UnderlyingTypeGroup { - Int, - Float, - UInt, - Bool, - Char, - String, -} - -pub(crate) fn get_type_ident(input: &DeriveInput) -> &syn::Ident { - if let Data::Struct(ref data_struct) = input.data { - if let Type::Path(ref path) = &data_struct.fields.iter().next().unwrap().ty { - return &path.path.segments.last().unwrap().ident; - } - } - panic!("Unsupported input") -} - -pub(crate) fn get_type_group(value_type: &syn::Ident) -> UnderlyingTypeGroup { - if value_type == "i8" - || value_type == "i16" - || value_type == "i32" - || value_type == "i64" - || value_type == "i128" - || value_type == "isize" - { - return UnderlyingTypeGroup::Int; - } - if value_type == "u8" - || value_type == "u16" - || value_type == "u32" - || value_type == "u64" - || value_type == "u128" - || value_type == "usize" - { - return UnderlyingTypeGroup::UInt; - } - if value_type == "f32" || value_type == "f64" { - return UnderlyingTypeGroup::Float; - } - if value_type == "bool" { - return UnderlyingTypeGroup::Bool; - } - if value_type == "char" { - return UnderlyingTypeGroup::Char; - } - if value_type == "String" { - return UnderlyingTypeGroup::String; - } - panic!("Unsupported type: {}", value_type); -} diff --git a/strong-type-derive/src/detail/underlying_type_utils.rs b/strong-type-derive/src/detail/underlying_type_utils.rs new file mode 100644 index 0000000..c5ee9f1 --- /dev/null +++ b/strong-type-derive/src/detail/underlying_type_utils.rs @@ -0,0 +1,104 @@ +use syn::{Data, DeriveInput, Type}; + +pub(crate) enum UnderlyingType { + Primitive, + Derived, +} +pub(crate) enum ValueTypeGroup { + Int(UnderlyingType), + Float(UnderlyingType), + UInt(UnderlyingType), + Bool(UnderlyingType), + Char(UnderlyingType), + String(UnderlyingType), +} + +pub(crate) struct TypeInfo { + pub primitive_type: syn::Ident, + pub value_type: syn::Ident, + pub type_group: ValueTypeGroup, +} + +fn get_type_ident(input: &DeriveInput) -> Option { + if let Data::Struct(ref data_struct) = input.data { + if let Type::Path(ref path) = &data_struct.fields.iter().next().unwrap().ty { + return Some(path.path.segments.last().unwrap().ident.clone()); + } + } + None +} + +fn get_primitive_from_custom_underlying(input: &DeriveInput) -> Option { + for attr in input.attrs.iter() { + if attr.path().is_ident("custom_underlying") { + let mut primitive = None; + attr.parse_nested_meta(|meta| match meta.path.get_ident() { + Some(ident) => { + primitive = Some(ident.clone()); + Ok(()) + } + None => Err(meta.error("Unsupported attribute")), + }) + .ok()?; + return primitive; + } + } + + None +} + +pub(crate) fn get_type(input: &DeriveInput) -> TypeInfo { + if let Some(value_type) = get_type_ident(input) { + match get_primitive_from_custom_underlying(input) { + Some(primitive_type) => TypeInfo { + primitive_type: primitive_type.clone(), + value_type: value_type.clone(), + type_group: get_type_group(&primitive_type, UnderlyingType::Derived), + }, + None => TypeInfo { + primitive_type: value_type.clone(), + value_type: value_type.clone(), + type_group: get_type_group(&value_type, UnderlyingType::Primitive), + }, + } + } else { + panic!("Unsupported input") + } +} + +pub(crate) fn get_type_group( + value_type: &syn::Ident, + underlying_type: UnderlyingType, +) -> ValueTypeGroup { + if value_type == "i8" + || value_type == "i16" + || value_type == "i32" + || value_type == "i64" + || value_type == "i128" + || value_type == "isize" + { + return ValueTypeGroup::Int(underlying_type); + } + if value_type == "u8" + || value_type == "u16" + || value_type == "u32" + || value_type == "u64" + || value_type == "u128" + || value_type == "usize" + { + return ValueTypeGroup::UInt(underlying_type); + } + if value_type == "f32" || value_type == "f64" { + return ValueTypeGroup::Float(underlying_type); + } + if value_type == "bool" { + return ValueTypeGroup::Bool(underlying_type); + } + if value_type == "char" { + return ValueTypeGroup::Char(underlying_type); + } + if value_type == "String" { + return ValueTypeGroup::String(underlying_type); + } + panic!("Unsupported type: {}", value_type); +} diff --git a/strong-type-derive/src/lib.rs b/strong-type-derive/src/lib.rs index 0d06ed1..ca75986 100644 --- a/strong-type-derive/src/lib.rs +++ b/strong-type-derive/src/lib.rs @@ -6,7 +6,7 @@ use syn::{parse_macro_input, DeriveInput}; use crate::strong_type::expand_strong_type; -#[proc_macro_derive(StrongType, attributes(strong_type))] +#[proc_macro_derive(StrongType, attributes(strong_type, custom_underlying))] pub fn strong_type(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_strong_type(input).into() diff --git a/strong-type-derive/src/strong_type.rs b/strong-type-derive/src/strong_type.rs index 71ac51d..4aeed85 100644 --- a/strong-type-derive/src/strong_type.rs +++ b/strong-type-derive/src/strong_type.rs @@ -1,8 +1,10 @@ use crate::detail::{ - get_attributes, get_type_group, get_type_ident, implement_arithmetic, implement_basic, - implement_basic_primitive, implement_basic_string, implement_bit_shift, implement_bool_ops, - implement_constants, implement_display, implement_hash, implement_nan, implement_negate, - is_struct_valid, StrongTypeAttributes, UnderlyingTypeGroup, + get_attributes, get_type, implement_arithmetic, implement_basic, implement_basic_primitive, + implement_basic_string, implement_bit_shift, implement_bool_ops, implement_constants, + implement_constants_derived, implement_display, implement_hash, implement_nan, + implement_negate, implement_primitive_accessor, implement_primitive_accessor_derived, + implement_primitive_str_accessor, implement_primitive_str_accessor_derived, is_struct_valid, + StrongTypeAttributes, TypeInfo, UnderlyingType, ValueTypeGroup, }; use proc_macro2::TokenStream; use quote::quote; @@ -14,64 +16,99 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream { } let name = &input.ident; - let value_type = get_type_ident(&input); - let group = get_type_group(value_type); + let TypeInfo { + primitive_type, + value_type, + type_group, + } = get_type(&input); let StrongTypeAttributes { has_auto_operators, has_custom_display, } = get_attributes(&input); let mut ast = quote!(); - ast.extend(implement_basic(name, value_type)); + ast.extend(implement_basic(name, &value_type)); if !has_custom_display { ast.extend(implement_display(name)); }; - match &group { - UnderlyingTypeGroup::Int | UnderlyingTypeGroup::UInt => { - ast.extend(implement_basic_primitive(name, value_type)); - ast.extend(implement_constants(name, value_type)); + match &type_group { + ValueTypeGroup::Int(underlying_type) + | ValueTypeGroup::UInt(underlying_type) + | ValueTypeGroup::Float(underlying_type) + | ValueTypeGroup::Bool(underlying_type) + | ValueTypeGroup::Char(underlying_type) => match underlying_type { + UnderlyingType::Primitive => { + ast.extend(implement_primitive_accessor(name, &value_type)) + } + UnderlyingType::Derived => { + ast.extend(implement_primitive_accessor_derived(name, &primitive_type)) + } + }, + ValueTypeGroup::String(UnderlyingType::Primitive) => { + ast.extend(implement_primitive_str_accessor(name)); + } + ValueTypeGroup::String(UnderlyingType::Derived) => { + ast.extend(implement_primitive_str_accessor_derived(name, &value_type)); + } + } + + match &type_group { + ValueTypeGroup::Int(underlying_type) | ValueTypeGroup::UInt(underlying_type) => { + ast.extend(implement_basic_primitive(name, &value_type)); ast.extend(implement_hash(name)); + + match underlying_type { + UnderlyingType::Primitive => ast.extend(implement_constants(name, &value_type)), + UnderlyingType::Derived => { + ast.extend(implement_constants_derived(name, &value_type)) + } + } } - UnderlyingTypeGroup::Float => { - ast.extend(implement_basic_primitive(name, value_type)); - ast.extend(implement_constants(name, value_type)); - ast.extend(implement_nan(name, value_type)); + ValueTypeGroup::Float(underlying_type) => { + ast.extend(implement_basic_primitive(name, &value_type)); + ast.extend(implement_nan(name, &value_type)); + match underlying_type { + UnderlyingType::Primitive => ast.extend(implement_constants(name, &value_type)), + UnderlyingType::Derived => { + ast.extend(implement_constants_derived(name, &value_type)) + } + } } - UnderlyingTypeGroup::Bool => { - ast.extend(implement_basic_primitive(name, value_type)); + ValueTypeGroup::Bool(_) => { + ast.extend(implement_basic_primitive(name, &value_type)); ast.extend(implement_hash(name)); } - UnderlyingTypeGroup::Char => { - ast.extend(implement_basic_primitive(name, value_type)); + ValueTypeGroup::Char(_) => { + ast.extend(implement_basic_primitive(name, &value_type)); ast.extend(implement_hash(name)); } - UnderlyingTypeGroup::String => { + ValueTypeGroup::String(_) => { ast.extend(implement_basic_string(name)); ast.extend(implement_hash(name)); } } if has_auto_operators { - match &group { - UnderlyingTypeGroup::Float => { + match &type_group { + ValueTypeGroup::Float(_) => { ast.extend(implement_arithmetic(name)); ast.extend(implement_negate(name)); } - UnderlyingTypeGroup::Int => { + ValueTypeGroup::Int(_) => { ast.extend(implement_arithmetic(name)); ast.extend(implement_negate(name)); ast.extend(implement_bit_shift(name)); } - UnderlyingTypeGroup::UInt => { + ValueTypeGroup::UInt(_) => { ast.extend(implement_arithmetic(name)); ast.extend(implement_bit_shift(name)); } - UnderlyingTypeGroup::Bool => { + ValueTypeGroup::Bool(_) => { ast.extend(implement_bool_ops(name)); } - _ => {} + ValueTypeGroup::Char(_) | ValueTypeGroup::String(_) => {} } } diff --git a/strong-type-tests/tests/custom_underlying.rs b/strong-type-tests/tests/custom_underlying.rs new file mode 100644 index 0000000..9ff8dcd --- /dev/null +++ b/strong-type-tests/tests/custom_underlying.rs @@ -0,0 +1,73 @@ +#[cfg(test)] +mod tests { + use std::mem; + use strong_type::StrongType; + + fn test_type() {} + + #[test] + fn test_custom_underlying() { + #[derive(StrongType)] + #[strong_type(auto_operators)] + struct Dollar(i32); + + #[derive(StrongType)] + #[strong_type(auto_operators)] + #[custom_underlying(i32)] + struct Cash(Dollar); + test_type::(); + assert_eq!(mem::size_of::(), mem::size_of::()); + + assert_eq!( + Cash::new(Dollar::new(10)), + Cash::new(Dollar::new(2)) + Cash::new(Dollar::new(8)) + ); + assert_eq!(Cash::new(Dollar::new(10)).primitive(), 10); + + assert_eq!( + format!("{}", Cash::new(Dollar::new(10))), + "Cash(Dollar(10))" + ); + + #[derive(StrongType)] + #[custom_underlying(i32)] + struct Coin(Cash); + test_type::(); + assert_eq!(mem::size_of::(), mem::size_of::()); + assert_eq!( + format!("{}", Coin::new(Cash::new(Dollar::new(10)))), + "Coin(Cash(Dollar(10)))" + ); + assert_eq!( + Coin::new(Cash::new(Dollar::new(10))).value(), + Cash::new(Dollar::new(10)) + ); + assert_eq!(Coin::new(Cash::new(Dollar::new(10))).primitive(), 10); + } + + #[test] + fn test_custom_string_underlying_with() { + #[derive(StrongType)] + struct Tag(String); + + #[derive(StrongType)] + #[custom_underlying(String)] + struct Name(Tag); + + test_type::(); + assert_eq!(mem::size_of::(), mem::size_of::()); + assert_eq!( + format!("{}", Name::new(Tag::new("tag".to_string()))), + "Name(Tag(tag))" + ); + + #[derive(StrongType)] + #[custom_underlying(String)] + struct Surname(Name); + assert_eq!(mem::size_of::(), mem::size_of::()); + assert_eq!( + format!("{}", Surname::new(Name::new(Tag::new("tag".to_string())))), + "Surname(Name(Tag(tag)))" + ); + } +} diff --git a/strong-type-tests/tests/tests.rs b/strong-type-tests/tests/tests.rs index 88e41f7..1d60509 100644 --- a/strong-type-tests/tests/tests.rs +++ b/strong-type-tests/tests/tests.rs @@ -1,3 +1,4 @@ mod auto_operators; +mod custom_underlying; mod display; mod strong_type;