Skip to content

Commit

Permalink
Derive StorageAccess impl for enums (#3460)
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejka authored Jul 16, 2023
1 parent 5134fd1 commit 1d645d0
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 9 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ genco = "0.17.0"
good_lp = { version = "1.3.2", features = ["minilp"], default-features = false }
id-arena = "2.2.1"
ignore = "0.4.20"
indent = "0.1.1"
indexmap = { version = "2.0.0", features = ["serde"] }
indoc = "2.0.1"
itertools = "0.11.0"
Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-starknet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ serde_json.workspace = true
sha3.workspace = true
smol_str.workspace = true
thiserror.workspace = true
indent.workspace = true

[dev-dependencies]
cairo-lang-diagnostics = { path = "../cairo-lang-diagnostics" }
Expand Down
13 changes: 11 additions & 2 deletions crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ impl TupleStructureStorePacking of starknet::StorePacking<TupleStructure, (felt2
}
}

#[derive(Drop, Serde, PartialEq, Copy, starknet::Store)]
enum Efg {
E: (),
F: (),
G: u256
}

#[derive(Drop, Serde, PartialEq, Copy, starknet::Store)]
struct AbcEtc {
a: u8,
Expand All @@ -56,6 +63,8 @@ struct AbcEtc {
k: EthAddress,
abc: Abc,
ts: TupleStructure,
efg1: Efg,
efg2: Efg,
}


Expand All @@ -80,7 +89,7 @@ mod test_contract {
}

#[test]
#[available_gas(2000000)]
#[available_gas(10000000)]
fn write_read_struct() {
let x = AbcEtc {
a: 1_u8,
Expand All @@ -98,7 +107,7 @@ fn write_read_struct() {
a: 1_u8, b: 2_u16, c: 3_u32,
}, ts: TupleStructure {
v1: 1_u256, v2: 2_u256,
}
}, efg1: Efg::E(()), efg2: Efg::G(123_u256)
};

assert(test_contract::__external::set_data(serialized_element(*@x)).is_empty(), 'Not empty');
Expand Down
4 changes: 0 additions & 4 deletions crates/cairo-lang-starknet/src/plugin/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,6 @@ fn get_field_kind_for_variant(

/// Derive the `Event` trait for enums annotated with `derive(starknet::Event)`.
pub fn handle_enum(db: &dyn SyntaxGroup, enum_ast: ast::ItemEnum) -> PluginResult {
if !derive_event_needed(&enum_ast, db) {
return PluginResult::default();
}

let mut builder = PatchBuilder::new(db);
let mut diagnostics = vec![];
let enum_name = RewriteNode::new_trimmed(enum_ast.name(db).as_syntax_node());
Expand Down
8 changes: 6 additions & 2 deletions crates/cairo-lang-starknet/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use events::derive_event_needed;
use storage_access::derive_storage_access_needed;

use self::contract::{handle_contract_by_storage, handle_module};
use self::events::handle_enum;

#[derive(Debug, Default)]
#[non_exhaustive]
Expand All @@ -45,7 +44,12 @@ impl MacroPlugin for StarkNetPlugin {
ast::Item::Struct(struct_ast) if struct_ast.has_attr(db, STORAGE_ATTR) => {
handle_contract_by_storage(db, struct_ast).unwrap_or_default()
}
ast::Item::Enum(enum_ast) => handle_enum(db, enum_ast),
ast::Item::Enum(enum_ast) if derive_storage_access_needed(&enum_ast, db) => {
storage_access::handle_enum(db, enum_ast)
}
ast::Item::Enum(enum_ast) if derive_event_needed(&enum_ast, db) => {
events::handle_enum(db, enum_ast)
}
// Nothing to do for other items.
_ => PluginResult::default(),
}
Expand Down
127 changes: 126 additions & 1 deletion crates/cairo-lang-starknet/src/plugin/storage_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use cairo_lang_syntax::attribute::structured::{
};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::{ast, TypedSyntaxNode};
use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode};
use indent::indent_by;
use indoc::formatdoc;

/// Derive the `Store` trait for structs annotated with `derive(starknet::Store)`.
Expand Down Expand Up @@ -153,6 +154,130 @@ pub fn handle_struct(db: &dyn SyntaxGroup, struct_ast: ast::ItemStruct) -> Plugi
}
}

/// Derive the `StorageAccess` trait for structs annotated with `derive(starknet::Store)`.
pub fn handle_enum(db: &dyn SyntaxGroup, enum_ast: ast::ItemEnum) -> PluginResult {
let enum_name = enum_ast.name(db).as_syntax_node().get_text_without_trivia(db);
let mut match_idx = Vec::new();
let mut match_idx_at_offset = Vec::new();

let mut match_value = Vec::new();
let mut match_value_at_offset = Vec::new();

let mut match_size = "".to_string();

for (i, variant) in enum_ast.variants(db).elements(db).iter().enumerate() {
let variant_name = variant.name(db).text(db);
let variant_type = match variant.type_clause(db) {
ast::OptionTypeClause::Empty(_) => "()".to_string(),
ast::OptionTypeClause::TypeClause(tc) => {
tc.ty(db).as_syntax_node().get_text_without_trivia(db)
}
};

match_idx.push(formatdoc!(
"if idx == {i} {{
starknet::SyscallResult::Ok(
{enum_name}::{variant_name}(
starknet::Store::read_at_offset(address_domain, base, 1_u8)?
)
)
}}",
));
match_idx_at_offset.push(formatdoc!(
"if idx == {i} {{
starknet::SyscallResult::Ok(
{enum_name}::{variant_name}(
starknet::Store::read_at_offset(address_domain, base, offset + 1_u8)?
)
)
}}",
));
match_value.push(formatdoc!(
"{enum_name}::{variant_name}(x) => {{
starknet::Store::write(address_domain, base, {i})?;
starknet::Store::write_at_offset(address_domain, base, 1_u8, x)?;
}}"
));
match_value_at_offset.push(formatdoc!(
"{enum_name}::{variant_name}(x) => {{
starknet::Store::write_at_offset(address_domain, base, offset, {i})?;
starknet::Store::write_at_offset(address_domain, base, offset + 1_u8, x)?;
}}"
));

if match_size.is_empty() {
match_size = format!("starknet::Store::<{variant_type}>::size()");
} else {
match_size =
format!("cmp::max(starknet::Store::<{variant_type}>::size(), {match_size})");
}
}

let sa_impl = formatdoc!(
"
impl Store{enum_name} of starknet::Store::<{enum_name}> {{
fn read(address_domain: u32, base: starknet::StorageBaseAddress) -> \
starknet::SyscallResult<{enum_name}> {{
let idx = starknet::Store::<felt252>::read(address_domain, base)?;
{match_idx}
else {{
let mut message = Default::default();
message.append('Incorrect index:');
message.append(idx);
starknet::SyscallResult::Err(message)
}}
}}
fn write(address_domain: u32, base: starknet::StorageBaseAddress, value: {enum_name}) \
-> starknet::SyscallResult<()> {{
match value {{
{match_value}
}};
starknet::SyscallResult::Ok(())
}}
fn read_at_offset(address_domain: u32, base: starknet::StorageBaseAddress, offset: u8) \
-> starknet::SyscallResult<{enum_name}> {{
let idx = starknet::Store::<felt252>::read_at_offset(address_domain, base, \
offset)?;
{match_idx_at_offset}
else {{
let mut message = Default::default();
message.append('Incorrect index:');
message.append(idx);
starknet::SyscallResult::Err(message)
}}
}}
#[inline(always)]
fn write_at_offset(address_domain: u32, base: starknet::StorageBaseAddress, offset: \
u8, value: {enum_name}) -> starknet::SyscallResult<()> {{
match value {{
{match_value_at_offset}
}};
starknet::SyscallResult::Ok(())
}}
#[inline(always)]
fn size() -> u8 {{
1_u8 + {match_size}
}}
}}",
match_idx = indent_by(8, match_idx.join("\nelse ")),
match_idx_at_offset = indent_by(8, match_idx_at_offset.join("\nelse ")),
match_value = indent_by(12, match_value.join(",\n")),
match_value_at_offset = indent_by(12, match_value_at_offset.join(",\n")),
);

let diagnostics = vec![];

PluginResult {
code: Some(PluginGeneratedFile {
name: "storage_access_impl".into(),
content: sa_impl,
aux_data: DynGeneratedFileAuxData(Arc::new(TrivialPluginAuxData {})),
}),
diagnostics,
remove_original_item: false,
}
}

/// Returns true if the type should be derived as a storage_access.
pub fn derive_storage_access_needed<T: QueryAttrs>(with_attrs: &T, db: &dyn SyntaxGroup) -> bool {
with_attrs.query_attr(db, "derive").into_iter().any(|attr| {
Expand Down

0 comments on commit 1d645d0

Please sign in to comment.