Skip to content

Commit

Permalink
perf: Use recurse_or_return in inner-most loop of MerkleRoot snippet
Browse files Browse the repository at this point in the history
Save another 11 % of rows in the processor table for the relevant
problem size. This is close to being the end-of-the-line for efficient
Merkle root calculation, I believe. Even fewer lines could be used if
the leaf number is known, by using the snippet for statically-known leaf
counts, but you pay a big price for program size (and thus hash table
row count) if you do that.

This completes the 1st optimization mentioned in #106.
  • Loading branch information
Sword-Smith committed Jun 11, 2024
1 parent f27f67e commit c7dfb8b
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 71 deletions.
8 changes: 4 additions & 4 deletions tasm-lib/benchmarks/tasmlib_hashing_merkle_root.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
{
"name": "tasmlib_hashing_merkle_root",
"benchmark_result": {
"clock_cycle_count": 9023,
"clock_cycle_count": 6961,
"hash_table_height": 3156,
"u32_table_height": 93,
"op_stack_table_height": 15570,
"op_stack_table_height": 13562,
"ram_table_height": 7673
},
"case": "CommonCase"
},
{
"name": "tasmlib_hashing_merkle_root",
"benchmark_result": {
"clock_cycle_count": 17759,
"clock_cycle_count": 13647,
"hash_table_height": 6228,
"u32_table_height": 110,
"op_stack_table_height": 30952,
"op_stack_table_height": 26900,
"ram_table_height": 15353
},
"case": "WorstCase"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
{
"name": "tasmlib_hashing_merkle_root_from_xfes_generic",
"benchmark_result": {
"clock_cycle_count": 9313,
"clock_cycle_count": 8277,
"hash_table_height": 3228,
"u32_table_height": 102,
"op_stack_table_height": 14580,
"op_stack_table_height": 13592,
"ram_table_height": 6653
},
"case": "CommonCase"
},
{
"name": "tasmlib_hashing_merkle_root_from_xfes_generic",
"benchmark_result": {
"clock_cycle_count": 18305,
"clock_cycle_count": 16243,
"hash_table_height": 6300,
"u32_table_height": 120,
"op_stack_table_height": 28938,
"op_stack_table_height": 26930,
"ram_table_height": 13309
},
"case": "WorstCase"
Expand Down
12 changes: 6 additions & 6 deletions tasm-lib/benchmarks/tasmlib_verifier_fri_verify.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
{
"name": "tasmlib_verifier_fri_verify",
"benchmark_result": {
"clock_cycle_count": 49813,
"hash_table_height": 14658,
"clock_cycle_count": 47754,
"hash_table_height": 14664,
"u32_table_height": 11448,
"op_stack_table_height": 49875,
"op_stack_table_height": 47867,
"ram_table_height": 17532
},
"case": "CommonCase"
},
{
"name": "tasmlib_verifier_fri_verify",
"benchmark_result": {
"clock_cycle_count": 49813,
"hash_table_height": 14658,
"clock_cycle_count": 47754,
"hash_table_height": 14664,
"u32_table_height": 11060,
"op_stack_table_height": 49875,
"op_stack_table_height": 47867,
"ram_table_height": 17532
},
"case": "WorstCase"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
{
"name": "tasmlib_verifier_stark_verify_inner_padded_height_256_fri_exp_4",
"benchmark_result": {
"clock_cycle_count": 184750,
"hash_table_height": 125245,
"u32_table_height": 15011,
"op_stack_table_height": 172612,
"clock_cycle_count": 182691,
"hash_table_height": 125251,
"u32_table_height": 15039,
"op_stack_table_height": 170604,
"ram_table_height": 274596
},
"case": "CommonCase"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
{
"name": "tasmlib_verifier_stark_verify_inner_padded_height_512_fri_exp_4",
"benchmark_result": {
"clock_cycle_count": 193497,
"hash_table_height": 132937,
"u32_table_height": 18254,
"op_stack_table_height": 178736,
"clock_cycle_count": 191438,
"hash_table_height": 132943,
"u32_table_height": 18488,
"op_stack_table_height": 176728,
"ram_table_height": 275774
},
"case": "CommonCase"
Expand Down
99 changes: 50 additions & 49 deletions tasm-lib/src/hashing/merkle_root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,44 +111,49 @@ impl BasicSnippet for MerkleRoot {
return
// _ current_len *next_level *current_level

// What is the stop-condition for `*current_level`?
// It must be `*curr - current_length * DIGEST_LENGTH`
dup 0
/*Update `current_len` */
swap 2
log_2_floor
push -1
add
push 2
pow
swap 2
// _ (current_len / 2) *next_level *current_level
// _ current_len' *next_level *current_level

// What is the stop-condition for `*next_level`?
// It must be `*next_level - current_len / 2 * DIGEST_LENGTH`
dup 1
dup 3
push {-(DIGEST_LENGTH as isize)}
mul
add
// _ current_len *next_level *current_level *current_level_stop
// _ current_len' *next_level *current_level *next_level_stop

swap 1
// _ current_len *next_level *current_level_stop *current_level
// _ current_len' *next_level *next_level_stop *current_level

dup 2
swap 1
// _ current_len *next_level *current_level_stop *next_level *current_level
// _ current_len' *next_level *next_level_stop *next_level *current_level

push 0
push 0
push 0
push 0

// _ current_len' *next_level *next_elem_stop *next_elem *curr_elem 0 0 0 0
call {calculate_parent_digests}
// _ current_len *next_level *current_level_stop *next_level' *current_level_stop
// _ current_len' *next_level *next_elem_stop *next_elem_stop *curr_elem_stop 0 0 0 0

pop 5
pop 1
swap 1
pop 1
// _ current_len *next_level *next_level_next

/*Update `current_len` */
swap 2
log_2_floor
push -1
add
push 2
pow
swap 2
// _ (current_len / 2) *next_level *next_level'
// _ current_len' *next_level *next_level'
// _ current_len' *next_level *next_elem_stop

/* Update `*current_level` based on `*next_level` */
swap 1
// _ (current_len / 2) *next_level' *next_level
// _ current_len' *next_level' *next_level

push {DIGEST_LENGTH - 1}
add
Expand All @@ -157,48 +162,40 @@ impl BasicSnippet for MerkleRoot {
recurse

// Populate the `*next` digest list
// START: _ *current_level_stop *next_last_elem_first_word *curr_last_word
// INVARIANT: _ *current_level_stop *next_elem *curr_elem
// END: _ *current_level_stop *next *current_level_stop
// INVARIANT: _ *next_elem_stop *next_elem *curr_elem 0 0 0 0
{calculate_parent_digests}:
dup 2
dup 1
eq
skiz
return
// _ *curr *next_elem *curr_elem[n]

dup 0
dup 4
read_mem {DIGEST_LENGTH}
read_mem {DIGEST_LENGTH}
// _ *curr *next_elem *curr_elem [right] [left] (*curr_elem[n] - 10)
// _ *curr *next_elem *curr_elem [right] [left] *curr_elem[n - 2]
// _ *curr *next_elem *curr_elem [right] [left] *curr_elem'
// _ *next_elem_stop *next_elem *curr_elem 0 0 0 0 [right] [left] (*curr_elem[n] - 10)
// _ *next_elem_stop *next_elem *curr_elem 0 0 0 0 [right] [left] *curr_elem[n - 2]
// _ *next_elem_stop *next_elem *curr_elem 0 0 0 0 [right] [left] *curr_elem'

swap 11
swap 15
pop 1
// _ *curr *next_elem *curr_elem' [right] [left]
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 [right] [left]

hash
// _ *curr *next_elem *curr_elem' [parent_digest]
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 [parent_digest]

dup 6
// _ *curr *next_elem *curr_elem' [parent_digest] *next_elem
dup 10
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 [parent_digest] *next_elem

write_mem {DIGEST_LENGTH}
// _ *curr *next_elem *curr_elem' (*next_elem + 5)
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 (*next_elem + 5)

push -10
add
// _ *curr *next_elem *curr_elem' (*next_elem - 5)
// _ *curr *next_elem *curr_elem' *next_elem[n-1]
// _ *curr *next_elem *curr_elem' *next_elem'
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 (*next_elem - 5)
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 *next_elem[n-1]
// _ *next_elem_stop *next_elem *curr_elem' 0 0 0 0 *next_elem'

swap 2
swap 6
pop 1
// _ *curr *next_elem' *curr_elem'
// _ *next_elem_stop *next_elem' *curr_elem' 0 0 0 0

recurse
recurse_or_return
)
}
}
Expand Down Expand Up @@ -246,13 +243,17 @@ impl Function for MerkleRoot {
}

fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
let height_0 = self.init_state(vec![Digest::default()], BFieldElement::zero());
let height_0_a = self.init_state(vec![Digest::default()], BFieldElement::zero());
let height_0_b = self.init_state(
vec![Digest::new([bfe!(6), bfe!(5), bfe!(4), bfe!(3), bfe!(2)])],
bfe!(1u64 << 44),
);
let height_1 = self.init_state(
vec![Digest::default(), Digest::default()],
BFieldElement::zero(),
);

vec![height_0, height_1]
vec![height_0_a, height_0_b, height_1]
}
}

Expand Down

0 comments on commit c7dfb8b

Please sign in to comment.