diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index 36ee2e08c1..0a0334ddb7 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -106,7 +106,14 @@ impl VersionizeAttribute { attribute_builder.into = Some(parse_path_ignore_quotes(&name_value.value)?); } - // parse versionize(bound = "Type: Bound") + // parse versionize(dispatch = "Type") + } else if name_value.path.is_ident("dispatch") { + if attribute_builder.dispatch_enum.is_some() { + return Err(Self::default_error(meta.span())); + } else { + attribute_builder.dispatch_enum = + Some(parse_path_ignore_quotes(&name_value.value)?); + } } else { return Err(Self::default_error(meta.span())); } diff --git a/utils/tfhe-versionable/examples/simple.rs b/utils/tfhe-versionable/examples/simple.rs index ce5871c9cd..16d0cb515f 100644 --- a/utils/tfhe-versionable/examples/simple.rs +++ b/utils/tfhe-versionable/examples/simple.rs @@ -6,8 +6,9 @@ use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispa // The structure that should be versioned, as defined in your code #[derive(Versionize)] -#[versionize(MyStructVersions)] // Link to the enum type that will holds all the versions of this - // type +// We have to link to the enum type that will holds all the versions of this +// type. This can also be written `#[versionize(dispatch = MyStructVersions)]`. +#[versionize(MyStructVersions)] struct MyStruct { attr: T, builtin: u32, diff --git a/utils/tfhe-versionable/tests/testcases/struct.rs b/utils/tfhe-versionable/tests/testcases/struct.rs index a951554eaa..d50747067b 100644 --- a/utils/tfhe-versionable/tests/testcases/struct.rs +++ b/utils/tfhe-versionable/tests/testcases/struct.rs @@ -37,6 +37,18 @@ pub struct MyStruct2 { field1: U, } +#[derive(Versionize)] +#[versionize(dispatch = MyStruct3Versions)] +pub struct MyStruct3 { + field0: u64, + field1: T, +} + +#[derive(VersionsDispatch)] +pub enum MyStruct3Versions { + V0(MyStruct3), +} + fn main() { assert_impl_all!(MyEmptyStruct: Version); assert_impl_all!(MyEmptyStruct2: Version); @@ -47,7 +59,9 @@ fn main() { assert_impl_all!(MyAnonStruct3: Version); - assert_impl_all!(MyStruct: Version); + assert_impl_all!(MyStruct: Versionize); assert_impl_all!(MyStruct2: Version); + + assert_impl_all!(MyStruct3: Versionize); }