Skip to content

Commit

Permalink
add more ensure_* fns to MastForestBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad committed Jul 19, 2024
1 parent 88eb223 commit c53df25
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 99 deletions.
10 changes: 3 additions & 7 deletions assembly/src/assembler/basic_block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use super::{
Instruction,
};
use alloc::{borrow::Borrow, string::ToString, vec::Vec};
use vm_core::{
mast::{MastNode, MastNodeId},
AdviceInjector, AssemblyOp, Operation,
};
use vm_core::{mast::MastNodeId, AdviceInjector, AssemblyOp, Operation};

// BASIC BLOCK BUILDER
// ================================================================================================
Expand Down Expand Up @@ -134,10 +131,9 @@ impl BasicBlockBuilder {
let ops = self.ops.drain(..).collect();
let decorators = self.decorators.drain(..).collect();

let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node);
let block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators));

Some(basic_block_node_id)
Some(block_node_id)
} else if !self.decorators.is_empty() {
// this is a bug in the assembler. we shouldn't have decorators added without their
// associated operations
Expand Down
26 changes: 9 additions & 17 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
};

use smallvec::SmallVec;
use vm_core::mast::{MastForest, MastNode, MastNodeId};
use vm_core::mast::{MastForest, MastNodeId};

/// Procedure Invocation
impl Assembler {
Expand Down Expand Up @@ -94,34 +94,28 @@ impl Assembler {
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
mast_forest_builder.ensure_external(mast_root)
})
}
InvokeKind::Call => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
mast_forest_builder.ensure_external(mast_root)
});

let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(call_node)
mast_forest_builder.ensure_call(callee_id)
}
InvokeKind::SysCall => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
mast_forest_builder.ensure_external(mast_root)
});

let syscall_node =
MastNode::new_syscall(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(syscall_node)
mast_forest_builder.ensure_syscall(callee_id)
}
}
};
Expand All @@ -134,7 +128,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_dynexec();

Ok(Some(dyn_node_id))
}
Expand All @@ -145,10 +139,8 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_call_node_id = {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());

mast_forest_builder.ensure_node(dyn_call_node)
let dyn_node_id = mast_forest_builder.ensure_dynexec();
mast_forest_builder.ensure_call(dyn_node_id)
};

Ok(Some(dyn_call_node_id))
Expand Down
53 changes: 52 additions & 1 deletion assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use core::ops::Index;

use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
DecoratorList, Operation,
};

/// Builder for a [`MastForest`].
Expand Down Expand Up @@ -44,7 +46,7 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
pub fn ensure_node(&mut self, node: impl Into<MastNode>) -> MastNodeId {
fn ensure_node(&mut self, node: impl Into<MastNode>) -> MastNodeId {
let node = node.into();
let node_digest = node.digest();

Expand All @@ -59,6 +61,55 @@ impl MastForestBuilder {
}
}

/// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_block(
&mut self,
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> MastNodeId {
match decorators {
Some(decorators) => {
self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
}
None => self.ensure_node(MastNode::new_basic_block(operations)),
}
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_join(&mut self, left_child: MastNodeId, right_child: MastNodeId) -> MastNodeId {
self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest))
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_split(&mut self, if_branch: MastNodeId, else_branch: MastNodeId) -> MastNodeId {
self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest))
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_loop(&mut self, body: MastNodeId) -> MastNodeId {
self.ensure_node(MastNode::new_loop(body, &self.mast_forest))
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_call(&mut self, callee: MastNodeId) -> MastNodeId {
self.ensure_node(MastNode::new_call(callee, &self.mast_forest))
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_syscall(&mut self, callee: MastNodeId) -> MastNodeId {
self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest))
}

/// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_dynexec(&mut self) -> MastNodeId {
self.ensure_node(MastNode::new_dynexec())
}

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_external(&mut self, mast_root: RpoDigest) -> MastNodeId {
self.ensure_node(MastNode::new_external(mast_root))
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
pub fn make_root(&mut self, new_root_id: MastNodeId) {
self.mast_forest.make_root(new_root_id)
Expand Down
20 changes: 5 additions & 15 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use mast_forest_builder::MastForestBuilder;
use vm_core::{
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId, MerkleTreeNode},
Decorator, DecoratorList, Kernel, Operation, Program,
};

Expand Down Expand Up @@ -790,12 +790,7 @@ impl Assembler {
let else_blk =
self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?;

let split_node_id = {
let split_node =
MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());

mast_forest_builder.ensure_node(split_node)
};
let split_node_id = mast_forest_builder.ensure_split(then_blk, else_blk);
mast_node_ids.push(split_node_id);
}

Expand Down Expand Up @@ -824,11 +819,7 @@ impl Assembler {
let loop_body_node_id =
self.compile_body(body.iter(), context, None, mast_forest_builder)?;

let loop_node_id = {
let loop_node =
MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(loop_node)
};
let loop_node_id = mast_forest_builder.ensure_loop(loop_body_node_id);
mast_node_ids.push(loop_node_id);
}
}
Expand All @@ -839,7 +830,7 @@ impl Assembler {
}

Ok(if mast_node_ids.is_empty() {
mast_forest_builder.ensure_node(vec![Operation::Noop])
mast_forest_builder.ensure_block(vec![Operation::Noop], None)
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)
})
Expand Down Expand Up @@ -899,8 +890,7 @@ fn combine_mast_node_ids(
while let (Some(left), Some(right)) =
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node);
let join_mast_node_id = mast_forest_builder.ensure_join(left, right);

mast_node_ids.push(join_mast_node_id);
}
Expand Down
90 changes: 31 additions & 59 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,10 @@ fn nested_blocks() {
// contains the MAST nodes for the kernel after a call to
// `Assembler::with_kernel_from_module()`.
let syscall_foo_node_id = {
let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(vec![Operation::Add]);
let kernel_foo_node_id =
expected_mast_forest_builder.ensure_block(vec![Operation::Add], None);

let syscall_node =
MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(syscall_node)
expected_mast_forest_builder.ensure_syscall(kernel_foo_node_id)
};

let program = r#"
Expand Down Expand Up @@ -129,74 +128,47 @@ fn nested_blocks() {
let exec_bar_node_id = {
// bar procedure
let basic_block_1_id =
expected_mast_forest_builder.ensure_node(vec![Operation::Push(17_u32.into())]);
expected_mast_forest_builder.ensure_block(vec![Operation::Push(17_u32.into())], None);

// Basic block representing the `foo` procedure
let basic_block_2_id =
expected_mast_forest_builder.ensure_node(vec![Operation::Push(19_u32.into())]);

let join_node = MastNode::new_join(
basic_block_1_id,
basic_block_2_id,
expected_mast_forest_builder.forest(),
);
expected_mast_forest_builder.ensure_node(join_node)
};
expected_mast_forest_builder.ensure_block(vec![Operation::Push(19_u32.into())], None);

let exec_foo_bar_baz_node_id = {
// basic block representing foo::bar.baz procedure
expected_mast_forest_builder.ensure_node(vec![Operation::Push(29_u32.into())])
expected_mast_forest_builder.ensure_join(basic_block_1_id, basic_block_2_id)
};

let before = expected_mast_forest_builder.ensure_node(vec![Operation::Push(2u32.into())]);
// basic block representing foo::bar.baz procedure
let exec_foo_bar_baz_node_id =
expected_mast_forest_builder.ensure_block(vec![Operation::Push(29_u32.into())], None);

let r#true1 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(3u32.into())]);
let r#false1 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(5u32.into())]);
let r#if1 = {
let r#if_node =
MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(r#if_node)
};
let before =
expected_mast_forest_builder.ensure_block(vec![Operation::Push(2u32.into())], None);

let r#true3 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(7u32.into())]);
let r#false3 = expected_mast_forest_builder.ensure_node(vec![Operation::Push(11u32.into())]);
let r#true2 = {
let r#if_node =
MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(r#if_node)
};
let r#true1 =
expected_mast_forest_builder.ensure_block(vec![Operation::Push(3u32.into())], None);
let r#false1 =
expected_mast_forest_builder.ensure_block(vec![Operation::Push(5u32.into())], None);
let r#if1 = expected_mast_forest_builder.ensure_split(r#true1, r#false1);

let r#true3 =
expected_mast_forest_builder.ensure_block(vec![Operation::Push(7u32.into())], None);
let r#false3 =
expected_mast_forest_builder.ensure_block(vec![Operation::Push(11u32.into())], None);
let r#true2 = expected_mast_forest_builder.ensure_split(r#true3, r#false3);

let r#while = {
let push_basic_block_id =
expected_mast_forest_builder.ensure_node(vec![Operation::Push(23u32.into())]);
let body_node_id = {
let body_node = MastNode::new_join(
exec_bar_node_id,
push_basic_block_id,
expected_mast_forest_builder.forest(),
);

expected_mast_forest_builder.ensure_node(body_node)
};

let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(loop_node)
expected_mast_forest_builder.ensure_block(vec![Operation::Push(23u32.into())], None);
let body_node_id =
expected_mast_forest_builder.ensure_join(exec_bar_node_id, push_basic_block_id);

expected_mast_forest_builder.ensure_loop(body_node_id)
};
let push_13_basic_block_id =
expected_mast_forest_builder.ensure_node(vec![Operation::Push(13u32.into())]);

let r#false2 = {
let node = MastNode::new_join(
push_13_basic_block_id,
r#while,
expected_mast_forest_builder.forest(),
);
expected_mast_forest_builder.ensure_node(node)
};
let nested = {
let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest());
expected_mast_forest_builder.ensure_node(node)
};
expected_mast_forest_builder.ensure_block(vec![Operation::Push(13u32.into())], None);

let r#false2 = expected_mast_forest_builder.ensure_join(push_13_basic_block_id, r#while);
let nested = expected_mast_forest_builder.ensure_split(r#true2, r#false2);

let combined_node_id = combine_mast_node_ids(
vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id],
Expand Down

0 comments on commit c53df25

Please sign in to comment.