From e1774c8ee563e794e8f2a25823f3c8dfca8036ae Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Tue, 4 Jun 2024 09:29:01 +0200 Subject: [PATCH] fix(sol-macro): allow deriving `Default` on contracts (#645) --- .../sol-macro-expander/src/expand/contract.rs | 27 ++++++++++++++++--- crates/sol-macro-input/src/attr.rs | 9 ++++--- crates/sol-macro-input/src/lib.rs | 2 +- crates/sol-types/tests/macros/sol/mod.rs | 20 ++++++++++++++ 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/crates/sol-macro-expander/src/expand/contract.rs b/crates/sol-macro-expander/src/expand/contract.rs index b45fa3bac..4790e3a84 100644 --- a/crates/sol-macro-expander/src/expand/contract.rs +++ b/crates/sol-macro-expander/src/expand/contract.rs @@ -138,23 +138,44 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result>(); + if derives.is_empty() { + continue; + } + + let len = derives.len(); + derives.retain(|derive| !derive.is_ident("Default")); + if derives.len() == len { + continue; + } + + attr.meta = parse_quote! { derive(#(#derives),*) }; + } let functions_enum = (!functions.is_empty()).then(|| { - let mut attrs = item_attrs.clone(); + let mut attrs = enum_attrs.clone(); let doc_str = format!("Container for all the [`{name}`](self) function calls."); attrs.push(parse_quote!(#[doc = #doc_str])); enum_expander.expand(ToExpand::Functions(&functions), attrs) }); let errors_enum = (!errors.is_empty()).then(|| { - let mut attrs = item_attrs.clone(); + let mut attrs = enum_attrs.clone(); let doc_str = format!("Container for all the [`{name}`](self) custom errors."); attrs.push(parse_quote!(#[doc = #doc_str])); enum_expander.expand(ToExpand::Errors(&errors), attrs) }); let events_enum = (!events.is_empty()).then(|| { - let mut attrs = item_attrs; + let mut attrs = enum_attrs; let doc_str = format!("Container for all the [`{name}`](self) events."); attrs.push(parse_quote!(#[doc = #doc_str])); enum_expander.expand(ToExpand::Events(&events), attrs) diff --git a/crates/sol-macro-input/src/attr.rs b/crates/sol-macro-input/src/attr.rs index a7ab4dd3d..eefc15707 100644 --- a/crates/sol-macro-input/src/attr.rs +++ b/crates/sol-macro-input/src/attr.rs @@ -57,9 +57,12 @@ pub fn derives(attrs: &[Attribute]) -> impl Iterator { /// Returns an iterator over all the rust `::` paths in the `#[derive(...)]` /// attributes. pub fn derives_mapped(attrs: &[Attribute]) -> impl Iterator + '_ { - derives(attrs).flat_map(|attr| { - attr.parse_args_with(Punctuated::::parse_terminated).unwrap_or_default() - }) + derives(attrs).flat_map(parse_derives) +} + +/// Parses the `#[derive(...)]` attributes into a list of paths. +pub fn parse_derives(attr: &Attribute) -> Punctuated { + attr.parse_args_with(Punctuated::::parse_terminated).unwrap_or_default() } // When adding a new attribute: diff --git a/crates/sol-macro-input/src/lib.rs b/crates/sol-macro-input/src/lib.rs index ed0e7c505..f6fa4a540 100644 --- a/crates/sol-macro-input/src/lib.rs +++ b/crates/sol-macro-input/src/lib.rs @@ -18,7 +18,7 @@ extern crate syn_solidity as ast; /// Tools for working with `#[...]` attributes. mod attr; -pub use attr::{derives_mapped, docs_str, mk_doc, ContainsSolAttrs, SolAttrs}; +pub use attr::{derives_mapped, docs_str, mk_doc, parse_derives, ContainsSolAttrs, SolAttrs}; mod input; pub use input::{SolInput, SolInputKind}; diff --git a/crates/sol-types/tests/macros/sol/mod.rs b/crates/sol-types/tests/macros/sol/mod.rs index a6504b629..ae79a5a9f 100644 --- a/crates/sol-types/tests/macros/sol/mod.rs +++ b/crates/sol-types/tests/macros/sol/mod.rs @@ -876,3 +876,23 @@ fn event_overrides() { assert_eq!(two::TestEvent_1::SIGNATURE, "TestEvent(bytes32,bytes32)"); assert_eq!(two::TestEvent_1::SIGNATURE_HASH, keccak256("TestEvent(bytes32,bytes32)")); } + +#[test] +fn contract_derive_default() { + sol! { + #[derive(Debug, Default)] + contract MyContract { + function f1(); + function f2(); + event e1(); + event e2(); + error c(); + } + } + + let MyContract::f1Call {} = MyContract::f1Call::default(); + let MyContract::f2Call {} = MyContract::f2Call::default(); + let MyContract::e1 {} = MyContract::e1::default(); + let MyContract::e2 {} = MyContract::e2::default(); + let MyContract::c {} = MyContract::c::default(); +}