diff --git a/Cargo.lock b/Cargo.lock index 1a493135570..77dfe7147a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -931,6 +931,7 @@ dependencies = [ "convert_case", "env_logger", "genco", + "indent", "indoc", "itertools 0.11.0", "log", @@ -1742,6 +1743,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "indent" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9f1a0777d972970f204fdf8ef319f1f4f8459131636d7e3c96c5d59570d0fa6" + [[package]] name = "indexmap" version = "1.9.3" diff --git a/Cargo.toml b/Cargo.toml index 29058e2d68b..054adf2b58d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,6 +71,7 @@ genco = "0.17.0" good_lp = { version = "1.3.2", features = ["minilp"], default-features = false } id-arena = "2.2.1" ignore = "0.4.20" +indent = "0.1.1" indexmap = { version = "2.0.0", features = ["serde"] } indoc = "2.0.1" itertools = "0.11.0" diff --git a/crates/cairo-lang-starknet/Cargo.toml b/crates/cairo-lang-starknet/Cargo.toml index 90bf04ee2ef..7de3416b127 100644 --- a/crates/cairo-lang-starknet/Cargo.toml +++ b/crates/cairo-lang-starknet/Cargo.toml @@ -39,6 +39,7 @@ serde_json.workspace = true sha3.workspace = true smol_str.workspace = true thiserror.workspace = true +indent.workspace = true [dev-dependencies] cairo-lang-diagnostics = { path = "../cairo-lang-diagnostics" } diff --git a/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo b/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo index 936f9605a9f..b7d6b648439 100644 --- a/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo +++ b/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo @@ -41,6 +41,13 @@ impl TupleStructureStorePacking of starknet::StorePacking PluginResult { - if !derive_event_needed(&enum_ast, db) { - return PluginResult::default(); - } - let mut builder = PatchBuilder::new(db); let mut diagnostics = vec![]; let enum_name = RewriteNode::new_trimmed(enum_ast.name(db).as_syntax_node()); diff --git a/crates/cairo-lang-starknet/src/plugin/mod.rs b/crates/cairo-lang-starknet/src/plugin/mod.rs index a01535630b5..e568f17b788 100644 --- a/crates/cairo-lang-starknet/src/plugin/mod.rs +++ b/crates/cairo-lang-starknet/src/plugin/mod.rs @@ -25,7 +25,6 @@ use events::derive_event_needed; use storage_access::derive_storage_access_needed; use self::contract::{handle_contract_by_storage, handle_module}; -use self::events::handle_enum; #[derive(Debug, Default)] #[non_exhaustive] @@ -45,7 +44,12 @@ impl MacroPlugin for StarkNetPlugin { ast::Item::Struct(struct_ast) if struct_ast.has_attr(db, STORAGE_ATTR) => { handle_contract_by_storage(db, struct_ast).unwrap_or_default() } - ast::Item::Enum(enum_ast) => handle_enum(db, enum_ast), + ast::Item::Enum(enum_ast) if derive_storage_access_needed(&enum_ast, db) => { + storage_access::handle_enum(db, enum_ast) + } + ast::Item::Enum(enum_ast) if derive_event_needed(&enum_ast, db) => { + events::handle_enum(db, enum_ast) + } // Nothing to do for other items. _ => PluginResult::default(), } diff --git a/crates/cairo-lang-starknet/src/plugin/storage_access.rs b/crates/cairo-lang-starknet/src/plugin/storage_access.rs index 6a8781b041c..8151dd6ee96 100644 --- a/crates/cairo-lang-starknet/src/plugin/storage_access.rs +++ b/crates/cairo-lang-starknet/src/plugin/storage_access.rs @@ -7,7 +7,8 @@ use cairo_lang_syntax::attribute::structured::{ }; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::helpers::QueryAttrs; -use cairo_lang_syntax::node::{ast, TypedSyntaxNode}; +use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode}; +use indent::indent_by; use indoc::formatdoc; /// Derive the `Store` trait for structs annotated with `derive(starknet::Store)`. @@ -153,6 +154,130 @@ pub fn handle_struct(db: &dyn SyntaxGroup, struct_ast: ast::ItemStruct) -> Plugi } } +/// Derive the `StorageAccess` trait for structs annotated with `derive(starknet::Store)`. +pub fn handle_enum(db: &dyn SyntaxGroup, enum_ast: ast::ItemEnum) -> PluginResult { + let enum_name = enum_ast.name(db).as_syntax_node().get_text_without_trivia(db); + let mut match_idx = Vec::new(); + let mut match_idx_at_offset = Vec::new(); + + let mut match_value = Vec::new(); + let mut match_value_at_offset = Vec::new(); + + let mut match_size = "".to_string(); + + for (i, variant) in enum_ast.variants(db).elements(db).iter().enumerate() { + let variant_name = variant.name(db).text(db); + let variant_type = match variant.type_clause(db) { + ast::OptionTypeClause::Empty(_) => "()".to_string(), + ast::OptionTypeClause::TypeClause(tc) => { + tc.ty(db).as_syntax_node().get_text_without_trivia(db) + } + }; + + match_idx.push(formatdoc!( + "if idx == {i} {{ + starknet::SyscallResult::Ok( + {enum_name}::{variant_name}( + starknet::Store::read_at_offset(address_domain, base, 1_u8)? + ) + ) + }}", + )); + match_idx_at_offset.push(formatdoc!( + "if idx == {i} {{ + starknet::SyscallResult::Ok( + {enum_name}::{variant_name}( + starknet::Store::read_at_offset(address_domain, base, offset + 1_u8)? + ) + ) + }}", + )); + match_value.push(formatdoc!( + "{enum_name}::{variant_name}(x) => {{ + starknet::Store::write(address_domain, base, {i})?; + starknet::Store::write_at_offset(address_domain, base, 1_u8, x)?; + }}" + )); + match_value_at_offset.push(formatdoc!( + "{enum_name}::{variant_name}(x) => {{ + starknet::Store::write_at_offset(address_domain, base, offset, {i})?; + starknet::Store::write_at_offset(address_domain, base, offset + 1_u8, x)?; + }}" + )); + + if match_size.is_empty() { + match_size = format!("starknet::Store::<{variant_type}>::size()"); + } else { + match_size = + format!("cmp::max(starknet::Store::<{variant_type}>::size(), {match_size})"); + } + } + + let sa_impl = formatdoc!( + " + impl Store{enum_name} of starknet::Store::<{enum_name}> {{ + fn read(address_domain: u32, base: starknet::StorageBaseAddress) -> \ + starknet::SyscallResult<{enum_name}> {{ + let idx = starknet::Store::::read(address_domain, base)?; + {match_idx} + else {{ + let mut message = Default::default(); + message.append('Incorrect index:'); + message.append(idx); + starknet::SyscallResult::Err(message) + }} + }} + fn write(address_domain: u32, base: starknet::StorageBaseAddress, value: {enum_name}) \ + -> starknet::SyscallResult<()> {{ + match value {{ + {match_value} + }}; + starknet::SyscallResult::Ok(()) + }} + fn read_at_offset(address_domain: u32, base: starknet::StorageBaseAddress, offset: u8) \ + -> starknet::SyscallResult<{enum_name}> {{ + let idx = starknet::Store::::read_at_offset(address_domain, base, \ + offset)?; + {match_idx_at_offset} + else {{ + let mut message = Default::default(); + message.append('Incorrect index:'); + message.append(idx); + starknet::SyscallResult::Err(message) + }} + }} + #[inline(always)] + fn write_at_offset(address_domain: u32, base: starknet::StorageBaseAddress, offset: \ + u8, value: {enum_name}) -> starknet::SyscallResult<()> {{ + match value {{ + {match_value_at_offset} + }}; + starknet::SyscallResult::Ok(()) + }} + #[inline(always)] + fn size() -> u8 {{ + 1_u8 + {match_size} + }} + }}", + match_idx = indent_by(8, match_idx.join("\nelse ")), + match_idx_at_offset = indent_by(8, match_idx_at_offset.join("\nelse ")), + match_value = indent_by(12, match_value.join(",\n")), + match_value_at_offset = indent_by(12, match_value_at_offset.join(",\n")), + ); + + let diagnostics = vec![]; + + PluginResult { + code: Some(PluginGeneratedFile { + name: "storage_access_impl".into(), + content: sa_impl, + aux_data: DynGeneratedFileAuxData(Arc::new(TrivialPluginAuxData {})), + }), + diagnostics, + remove_original_item: false, + } +} + /// Returns true if the type should be derived as a storage_access. pub fn derive_storage_access_needed(with_attrs: &T, db: &dyn SyntaxGroup) -> bool { with_attrs.query_attr(db, "derive").into_iter().any(|attr| {