Skip to content

Commit

Permalink
fix(brillig): Brillig entry point analysis and function specializatio…
Browse files Browse the repository at this point in the history
…n through duplication (#7277)

Co-authored-by: Tom French <tom@tomfren.ch>
Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 18, 2025
1 parent 1dc6a8b commit 119bf62
Show file tree
Hide file tree
Showing 16 changed files with 810 additions and 155 deletions.
4 changes: 2 additions & 2 deletions .github/benchmark_projects.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ projects:
path: noir-projects/noir-protocol-circuits/crates/private-kernel-reset
num_runs: 5
timeout: 250
compilation-timeout: 7
compilation-timeout: 8
execution-timeout: 0.35
compilation-memory-limit: 750
execution-memory-limit: 300
Expand Down Expand Up @@ -74,7 +74,7 @@ projects:
num_runs: 1
timeout: 60
compilation-timeout: 100
execution-timeout: 35
execution-timeout: 40
compilation-memory-limit: 7000
execution-memory-limit: 1500
rollup-merge:
Expand Down
197 changes: 52 additions & 145 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};

use acvm::FieldElement;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use super::{
BrilligArtifact, BrilligBlock, BrilligVariable, Function, FunctionContext, Label, ValueId,
};
use crate::brillig::{
brillig_ir::BrilligContext, called_functions_vec, Brillig, BrilligOptions, DataFlowGraph,
FunctionId, Instruction, Value,
use crate::{
brillig::{brillig_ir::BrilligContext, Brillig, BrilligOptions, DataFlowGraph, FunctionId},
ssa::opt::brillig_entry_points::{build_inner_call_to_entry_points, get_brillig_entry_points},
};

/// Context structure for generating Brillig globals
Expand All @@ -24,12 +24,12 @@ pub(crate) struct BrilligGlobals {
/// Maps a Brillig entry point to all functions called in that entry point.
/// This includes any nested calls as well, as we want to be able to associate
/// any Brillig function with the appropriate global allocations.
brillig_entry_points: HashMap<FunctionId, HashSet<FunctionId>>,
brillig_entry_points: BTreeMap<FunctionId, BTreeSet<FunctionId>>,

/// Maps an inner call to its Brillig entry point
/// This is simply used to simplify fetching global allocations when compiling
/// individual Brillig functions.
inner_call_to_entry_point: HashMap<FunctionId, Vec<FunctionId>>,
inner_call_to_entry_point: HashMap<FunctionId, BTreeSet<FunctionId>>,
/// Final map that associated an entry point with its Brillig global allocations
entry_point_globals_map: HashMap<FunctionId, SsaToBrilligGlobals>,
}
Expand All @@ -43,106 +43,27 @@ impl BrilligGlobals {
mut used_globals: HashMap<FunctionId, HashSet<ValueId>>,
main_id: FunctionId,
) -> Self {
let mut brillig_entry_points = HashMap::default();
let acir_functions = functions.iter().filter(|(_, func)| func.runtime().is_acir());
for (_, function) in acir_functions {
for block_id in function.reachable_blocks() {
for instruction_id in function.dfg[block_id].instructions() {
let instruction = &function.dfg[*instruction_id];
let Instruction::Call { func: func_id, arguments: _ } = instruction else {
continue;
};

let func_value = &function.dfg[*func_id];
let Value::Function(func_id) = func_value else { continue };

let called_function = &functions[func_id];
if called_function.runtime().is_acir() {
continue;
}

// We have now found a Brillig entry point.
// Let's recursively build a call graph to determine any functions
// whose parent is this entry point and any globals used in those internal calls.
brillig_entry_points.insert(*func_id, HashSet::default());
Self::mark_entry_points_calls_recursive(
functions,
*func_id,
called_function,
&mut used_globals,
&mut brillig_entry_points,
im::HashSet::new(),
);
}
let brillig_entry_points = get_brillig_entry_points(functions, main_id);

// Mark any globals used in a Brillig entry point.
// Using the information collected we can determine which globals
// an entry point must initialize.
for (entry_point, entry_point_inner_calls) in brillig_entry_points.iter() {
for inner_call in entry_point_inner_calls.iter() {
let inner_globals = used_globals
.get(inner_call)
.expect("Should have a slot for each function")
.clone();
used_globals
.get_mut(entry_point)
.expect("ICE: should have func")
.extend(inner_globals);
}
}

// If main has been marked as Brillig, it is itself an entry point.
// Run the same analysis from above on main.
let main_func = &functions[&main_id];
if main_func.runtime().is_brillig() {
brillig_entry_points.insert(main_id, HashSet::default());
Self::mark_entry_points_calls_recursive(
functions,
main_id,
main_func,
&mut used_globals,
&mut brillig_entry_points,
im::HashSet::new(),
);
}
let inner_call_to_entry_point = build_inner_call_to_entry_points(&brillig_entry_points);

// NB: Temporary fix to override entry point analysis
let merged_set =
used_globals.values().flat_map(|set| set.iter().copied()).collect::<HashSet<_>>();
for set in used_globals.values_mut() {
*set = merged_set.clone();
}

Self { used_globals, brillig_entry_points, ..Default::default() }
}

/// Recursively mark any functions called in an entry point as well as
/// any globals used in those functions.
/// Using the information collected we can determine which globals
/// an entry point must initialize.
fn mark_entry_points_calls_recursive(
functions: &BTreeMap<FunctionId, Function>,
entry_point: FunctionId,
called_function: &Function,
used_globals: &mut HashMap<FunctionId, HashSet<ValueId>>,
brillig_entry_points: &mut HashMap<FunctionId, HashSet<FunctionId>>,
mut explored_functions: im::HashSet<FunctionId>,
) {
if explored_functions.insert(called_function.id()).is_some() {
return;
}

let inner_calls = called_functions_vec(called_function).into_iter().collect::<HashSet<_>>();

for inner_call in inner_calls {
let inner_globals = used_globals
.get(&inner_call)
.expect("Should have a slot for each function")
.clone();
used_globals
.get_mut(&entry_point)
.expect("ICE: should have func")
.extend(inner_globals);

if let Some(inner_calls) = brillig_entry_points.get_mut(&entry_point) {
inner_calls.insert(inner_call);
}

Self::mark_entry_points_calls_recursive(
functions,
entry_point,
&functions[&inner_call],
used_globals,
brillig_entry_points,
explored_functions.clone(),
);
}
Self { used_globals, brillig_entry_points, inner_call_to_entry_point, ..Default::default() }
}

pub(crate) fn declare_globals(
Expand All @@ -151,18 +72,11 @@ impl BrilligGlobals {
brillig: &mut Brillig,
options: &BrilligOptions,
) {
// Map for fetching the correct entry point globals when compiling any function
let mut inner_call_to_entry_point: HashMap<FunctionId, Vec<FunctionId>> =
HashMap::default();
let mut entry_point_globals_map = HashMap::default();
// We only need to generate globals for entry points
for (entry_point, entry_point_inner_calls) in self.brillig_entry_points.iter() {
for (entry_point, _) in self.brillig_entry_points.iter() {
let entry_point = *entry_point;

for inner_call in entry_point_inner_calls {
inner_call_to_entry_point.entry(*inner_call).or_default().push(entry_point);
}

let used_globals = self.used_globals.remove(&entry_point).unwrap_or_default();
let (artifact, brillig_globals, globals_size) =
convert_ssa_globals(options, globals_dfg, &used_globals, entry_point);
Expand All @@ -173,46 +87,41 @@ impl BrilligGlobals {
brillig.globals_memory_size.insert(entry_point, globals_size);
}

self.inner_call_to_entry_point = inner_call_to_entry_point;
self.entry_point_globals_map = entry_point_globals_map;
}

/// Fetch the global allocations that can possibly be accessed
/// by any given Brillig function (non-entry point or entry point).
/// The allocations available to a function are determined by its entry point.
/// For a given function id input, this function will search for that function's
/// entry point (or multiple entry points) and fetch the global allocations
/// associated with those entry points.
/// entry point and fetch the global allocations associated with that entry point.
/// These allocations can then be used when compiling the Brillig function
/// and resolving global variables.
pub(crate) fn get_brillig_globals(
&self,
brillig_function_id: FunctionId,
) -> SsaToBrilligGlobals {
let entry_points = self.inner_call_to_entry_point.get(&brillig_function_id);
) -> Option<&SsaToBrilligGlobals> {
if let Some(globals) = self.entry_point_globals_map.get(&brillig_function_id) {
// Check whether `brillig_function_id` is itself an entry point.
// If so, return the global allocations directly from `self.entry_point_globals_map`.
return Some(globals);
}

let mut globals_allocations = HashMap::default();
if let Some(entry_points) = entry_points {
// A Brillig function is used by multiple entry points. Fetch both globals allocations
// in case one is used by the internal call.
let entry_point_allocations = entry_points
.iter()
.flat_map(|entry_point| self.entry_point_globals_map.get(entry_point))
.collect::<Vec<_>>();
for map in entry_point_allocations {
globals_allocations.extend(map);
}
} else if let Some(globals) = self.entry_point_globals_map.get(&brillig_function_id) {
// If there is no mapping from an inner call to an entry point, that means `brillig_function_id`
// is itself an entry point and we can fetch the global allocations directly from `self.entry_point_globals_map`.
// vec![globals]
globals_allocations.extend(globals);
} else {
let entry_points = self.inner_call_to_entry_point.get(&brillig_function_id);
let Some(entry_points) = entry_points else {
unreachable!(
"ICE: Expected global allocation to be set for function {brillig_function_id}"
);
};

// Sanity check: We should have guaranteed earlier that an inner call has only a single entry point
assert_eq!(entry_points.len(), 1, "{brillig_function_id} has multiple entry points");
let entry_point = entry_points.first().expect("ICE: Inner call should have an entry point");
if let Some(globals) = self.entry_point_globals_map.get(entry_point) {
return Some(globals);
}
globals_allocations

None
}
}

Expand Down Expand Up @@ -312,11 +221,10 @@ mod tests {
if func_id.to_u32() == 1 {
assert_eq!(
artifact.byte_code.len(),
2,
1,
"Expected just a `Return`, but got more than a single opcode"
);
// TODO: Bring this back (https://github.com/noir-lang/noir/issues/7306)
// assert!(matches!(&artifact.byte_code[0], Opcode::Return));
assert!(matches!(&artifact.byte_code[0], Opcode::Return));
} else if func_id.to_u32() == 2 {
assert_eq!(
artifact.byte_code.len(),
Expand Down Expand Up @@ -430,17 +338,16 @@ mod tests {
if func_id.to_u32() == 1 {
assert_eq!(
artifact.byte_code.len(),
30,
2,
"Expected enough opcodes to initialize the globals"
);
// TODO: Bring this back (https://github.com/noir-lang/noir/issues/7306)
// let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else {
// panic!("First opcode is expected to be `Const`");
// };
// assert_eq!(destination.unwrap_direct(), GlobalSpace::start());
// assert!(matches!(bit_size, BitSize::Field));
// assert_eq!(*value, FieldElement::from(1u128));
// assert!(matches!(&artifact.byte_code[1], Opcode::Return));
let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else {
panic!("First opcode is expected to be `Const`");
};
assert_eq!(destination.unwrap_direct(), GlobalSpace::start());
assert!(matches!(bit_size, BitSize::Field));
assert_eq!(*value, FieldElement::from(1u128));
assert!(matches!(&artifact.byte_code[1], Opcode::Return));
} else if func_id.to_u32() == 2 || func_id.to_u32() == 3 {
// We want the entry point which uses globals (f2) and the entry point which calls f2 function internally (f3 through f4)
// to have the same globals initialized.
Expand Down
11 changes: 6 additions & 5 deletions compiler/noirc_evaluator/src/brillig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ use crate::ssa::{
ir::{
dfg::DataFlowGraph,
function::{Function, FunctionId},
instruction::Instruction,
value::{Value, ValueId},
value::ValueId,
},
opt::inlining::called_functions_vec,
ssa_gen::Ssa,
};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
Expand Down Expand Up @@ -122,10 +120,13 @@ impl Ssa {
brillig_globals.declare_globals(&globals_dfg, &mut brillig, options);

for brillig_function_id in brillig_reachable_function_ids {
let globals_allocations = brillig_globals.get_brillig_globals(brillig_function_id);
let empty_allocations = HashMap::default();
let globals_allocations = brillig_globals
.get_brillig_globals(brillig_function_id)
.unwrap_or(&empty_allocations);

let func = &self.functions[&brillig_function_id];
brillig.compile(func, options, &globals_allocations);
brillig.compile(func, options, globals_allocations);
}

brillig
Expand Down
10 changes: 8 additions & 2 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ pub(crate) fn optimize_into_acir(
.run_pass(|ssa| ssa.fold_constants_with_brillig(&brillig), "Inlining Brillig Calls Inlining")
// It could happen that we inlined all calls to a given brillig function.
// In that case it's unused so we can remove it. This is what we check next.
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (3rd)")
.run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (2nd)")
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (4th)")
.run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (3rd)")
.finish();

if !options.skip_underconstrained_check {
Expand Down Expand Up @@ -217,6 +217,12 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result<Ss
.run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (1st)")
.run_pass(Ssa::simplify_cfg, "Simplifying (3rd):")
.run_pass(Ssa::array_set_optimization, "Array Set Optimizations")
// The Brillig globals pass expected that we have the used globals map set for each function.
// The used globals map is determined during DIE, so we should duplicate entry points before a DIE pass run.
.run_pass(Ssa::brillig_entry_point_analysis, "Brillig Entry Point Analysis")
// Remove any potentially unnecessary duplication from the Brillig entry point analysis.
.run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (3rd)")
.run_pass(Ssa::dead_instruction_elimination, "Dead Instruction Elimination (2nd)")
.finish())
}

Expand Down
Loading

0 comments on commit 119bf62

Please sign in to comment.