Skip to content

Commit

Permalink
NestedStrongType with custom_underlying
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjhongwu committed Feb 2, 2024
1 parent 72e6c45 commit 7a5b2f6
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 60 deletions.
13 changes: 13 additions & 0 deletions strong-type-derive/src/detail/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,16 @@ pub(crate) fn implement_constants(name: &syn::Ident, value_type: &syn::Ident) ->
}
}
}
pub(crate) fn implement_constants_derive(
name: &syn::Ident,
value_type: &syn::Ident,
) -> TokenStream {
quote! {
impl #name {
pub const MIN: Self = Self(#value_type::MIN);
pub const MAX: Self = Self(#value_type::MAX);
pub const ZERO: Self = Self(#value_type::ZERO);
pub const ONE: Self = Self(#value_type::ONE);
}
}
}
4 changes: 2 additions & 2 deletions strong-type-derive/src/detail/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ pub(crate) use basic_primitive::implement_basic_primitive;
pub(crate) use basic_string::implement_basic_string;
pub(crate) use bit_ops::implement_bit_shift;
pub(crate) use bool_ops::implement_bool_ops;
pub(crate) use constants::implement_constants;
pub(crate) use constants::{implement_constants, implement_constants_derive};
pub(crate) use display::implement_display;
pub(crate) use hash::implement_hash;
pub(crate) use nan::implement_nan;
pub(crate) use negate::implement_negate;
pub(crate) use underlying_type::{get_type_group, get_type_ident, UnderlyingTypeGroup};
pub(crate) use underlying_type::{get_type, TypeInfo, UnderlyingType, ValueTypeGroup};
pub(crate) use utils::{get_attributes, is_struct_valid, StrongTypeAttributes};
120 changes: 83 additions & 37 deletions strong-type-derive/src/detail/underlying_type.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,101 @@
use syn::{Data, DeriveInput, Type};

#[repr(u8)]
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Copy, Clone)]
pub(crate) enum UnderlyingTypeGroup {
Int,
Float,
UInt,
Bool,
Char,
String,
pub(crate) enum UnderlyingType {
Primitive,
Derived,
}
pub(crate) enum ValueTypeGroup {
Int(UnderlyingType),
Float(UnderlyingType),
UInt(UnderlyingType),
Bool(UnderlyingType),
Char(UnderlyingType),
String(UnderlyingType),
}

pub(crate) struct TypeInfo<'a> {
pub value_type: &'a syn::Ident,
pub type_group: ValueTypeGroup,
}

pub(crate) fn get_type_ident(input: &DeriveInput) -> &syn::Ident {
fn get_type_ident(input: &DeriveInput) -> Option<&syn::Ident> {
if let Data::Struct(ref data_struct) = input.data {
if let Type::Path(ref path) = &data_struct.fields.iter().next().unwrap().ty {
return &path.path.segments.last().unwrap().ident;
return Some(&path.path.segments.last().unwrap().ident);
}
}
None
}

fn get_group_from_custom_underlying(input: &DeriveInput) -> Option<ValueTypeGroup> {
for attr in input.attrs.iter() {
if attr.path().is_ident("custom_underlying") {
let mut type_group = None;
attr.parse_nested_meta(|meta| {
if let Some(ident) = meta.path.get_ident() {
type_group = Some(get_type_group(ident, UnderlyingType::Derived));
Ok(())
} else {
Err(meta.error("Unsupported attribute"))
}
})
.ok()?;
return type_group;
}
}
panic!("Unsupported input")

None
}

pub(crate) fn get_type_group(value_type: &syn::Ident) -> UnderlyingTypeGroup {
if value_type == "i8"
|| value_type == "i16"
|| value_type == "i32"
|| value_type == "i64"
|| value_type == "i128"
|| value_type == "isize"
pub(crate) fn get_type(input: &DeriveInput) -> TypeInfo {
if let Some(value_type) = get_type_ident(input) {
let type_group = match get_group_from_custom_underlying(input) {
Some(type_group) => type_group,
None => get_type_group(value_type, UnderlyingType::Primitive),
};

TypeInfo {
value_type,
type_group,
}
} else {
panic!("Unsupported input")
}
}

pub(crate) fn get_type_group(
underlying_type: &syn::Ident,
value_type: UnderlyingType,
) -> ValueTypeGroup {
if underlying_type == "i8"
|| underlying_type == "i16"
|| underlying_type == "i32"
|| underlying_type == "i64"
|| underlying_type == "i128"
|| underlying_type == "isize"
{
return UnderlyingTypeGroup::Int;
}
if value_type == "u8"
|| value_type == "u16"
|| value_type == "u32"
|| value_type == "u64"
|| value_type == "u128"
|| value_type == "usize"
return ValueTypeGroup::Int(value_type);
}
if underlying_type == "u8"
|| underlying_type == "u16"
|| underlying_type == "u32"
|| underlying_type == "u64"
|| underlying_type == "u128"
|| underlying_type == "usize"
{
return UnderlyingTypeGroup::UInt;
return ValueTypeGroup::UInt(value_type);
}
if value_type == "f32" || value_type == "f64" {
return UnderlyingTypeGroup::Float;
if underlying_type == "f32" || underlying_type == "f64" {
return ValueTypeGroup::Float(value_type);
}
if value_type == "bool" {
return UnderlyingTypeGroup::Bool;
if underlying_type == "bool" {
return ValueTypeGroup::Bool(value_type);
}
if value_type == "char" {
return UnderlyingTypeGroup::Char;
if underlying_type == "char" {
return ValueTypeGroup::Char(value_type);
}
if value_type == "String" {
return UnderlyingTypeGroup::String;
if underlying_type == "String" {
return ValueTypeGroup::String(value_type);
}
panic!("Unsupported type: {}", value_type);
panic!("Unsupported type: {}", underlying_type);
}
2 changes: 1 addition & 1 deletion strong-type-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use syn::{parse_macro_input, DeriveInput};

use crate::strong_type::expand_strong_type;

#[proc_macro_derive(StrongType, attributes(strong_type))]
#[proc_macro_derive(StrongType, attributes(strong_type, custom_underlying))]
pub fn strong_type(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
expand_strong_type(input).into()
Expand Down
49 changes: 29 additions & 20 deletions strong-type-derive/src/strong_type.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::detail::{
get_attributes, get_type_group, get_type_ident, implement_arithmetic, implement_basic,
implement_basic_primitive, implement_basic_string, implement_bit_shift, implement_bool_ops,
implement_constants, implement_display, implement_hash, implement_nan, implement_negate,
is_struct_valid, StrongTypeAttributes, UnderlyingTypeGroup,
get_attributes, get_type, implement_arithmetic, implement_basic, implement_basic_primitive,
implement_basic_string, implement_bit_shift, implement_bool_ops, implement_constants,
implement_constants_derive, implement_display, implement_hash, implement_nan, implement_negate,
is_struct_valid, StrongTypeAttributes, TypeInfo, UnderlyingType, ValueTypeGroup,
};
use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -14,8 +14,10 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream {
}

let name = &input.ident;
let value_type = get_type_ident(&input);
let group = get_type_group(value_type);
let TypeInfo {
value_type,
type_group,
} = get_type(&input);
let StrongTypeAttributes {
has_auto_operators,
has_custom_display,
Expand All @@ -28,50 +30,57 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream {
ast.extend(implement_display(name));
};

match &group {
UnderlyingTypeGroup::Int | UnderlyingTypeGroup::UInt => {
match &type_group {
ValueTypeGroup::Int(underlying_type) | ValueTypeGroup::UInt(underlying_type) => {
ast.extend(implement_basic_primitive(name, value_type));
ast.extend(implement_constants(name, value_type));
ast.extend(implement_hash(name));

match underlying_type {
UnderlyingType::Primitive => ast.extend(implement_constants(name, value_type)),
UnderlyingType::Derived => ast.extend(implement_constants_derive(name, value_type)),
}
}
UnderlyingTypeGroup::Float => {
ValueTypeGroup::Float(underlying_type) => {
ast.extend(implement_basic_primitive(name, value_type));
ast.extend(implement_constants(name, value_type));
ast.extend(implement_nan(name, value_type));
match underlying_type {
UnderlyingType::Primitive => ast.extend(implement_constants(name, value_type)),
UnderlyingType::Derived => ast.extend(implement_constants_derive(name, value_type)),
}
}
UnderlyingTypeGroup::Bool => {
ValueTypeGroup::Bool(_) => {
ast.extend(implement_basic_primitive(name, value_type));
ast.extend(implement_hash(name));
}
UnderlyingTypeGroup::Char => {
ValueTypeGroup::Char(_) => {
ast.extend(implement_basic_primitive(name, value_type));
ast.extend(implement_hash(name));
}
UnderlyingTypeGroup::String => {
ValueTypeGroup::String(_) => {
ast.extend(implement_basic_string(name));
ast.extend(implement_hash(name));
}
}

if has_auto_operators {
match &group {
UnderlyingTypeGroup::Float => {
match &type_group {
ValueTypeGroup::Float(_) => {
ast.extend(implement_arithmetic(name));
ast.extend(implement_negate(name));
}
UnderlyingTypeGroup::Int => {
ValueTypeGroup::Int(_) => {
ast.extend(implement_arithmetic(name));
ast.extend(implement_negate(name));
ast.extend(implement_bit_shift(name));
}
UnderlyingTypeGroup::UInt => {
ValueTypeGroup::UInt(_) => {
ast.extend(implement_arithmetic(name));
ast.extend(implement_bit_shift(name));
}
UnderlyingTypeGroup::Bool => {
ValueTypeGroup::Bool(_) => {
ast.extend(implement_bool_ops(name));
}
_ => {}
ValueTypeGroup::Char(_) | ValueTypeGroup::String(_) => {}
}
}

Expand Down
41 changes: 41 additions & 0 deletions strong-type-tests/tests/custom_underlying.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#[cfg(test)]
mod tests {
use std::mem;
use strong_type::StrongType;

fn test_type<T: std::fmt::Debug + Clone + Send + Sync + Default + PartialEq>() {}

#[test]
fn test_custom_underlying() {
#[derive(StrongType)]
#[strong_type(auto_operators)]
struct Dollar(i32);

#[derive(StrongType)]
#[strong_type(auto_operators)]
#[custom_underlying(i32)]
struct Cash(Dollar);
test_type::<Cash>();
assert_eq!(mem::size_of::<Cash>(), mem::size_of::<i32>());

assert_eq!(
Cash::new(Dollar::new(10)),
Cash::new(Dollar::new(2)) + Cash::new(Dollar::new(8))
);

assert_eq!(
format!("{}", Cash::new(Dollar::new(10))),
"Cash(Dollar(10))"
);

#[derive(StrongType)]
#[custom_underlying(i32)]
struct Coin(Cash);
test_type::<Coin>();
assert_eq!(mem::size_of::<Coin>(), mem::size_of::<i32>());
assert_eq!(
format!("{}", Coin::new(Cash::new(Dollar::new(10)))),
"Coin(Cash(Dollar(10)))"
);
}
}
1 change: 1 addition & 0 deletions strong-type-tests/tests/tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod auto_operators;
mod custom_underlying;
mod display;
mod strong_type;

0 comments on commit 7a5b2f6

Please sign in to comment.