Skip to content

Commit

Permalink
feat(performance): Check sub operations against induction variables (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vezenovm authored Feb 12, 2025
1 parent 5d427c8 commit 7cdce1f
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 35 deletions.
146 changes: 121 additions & 25 deletions compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ impl Loops {
continue;
};

context.hoist_loop_invariants(&loop_, pre_header);
context.current_pre_header = Some(pre_header);
context.hoist_loop_invariants(&loop_);
}

context.map_dependent_instructions();
Expand Down Expand Up @@ -89,15 +90,19 @@ struct LoopInvariantContext<'f> {
inserter: FunctionInserter<'f>,
defined_in_loop: HashSet<ValueId>,
loop_invariants: HashSet<ValueId>,
// Maps current loop induction variable -> fixed upper loop bound
// Maps current loop induction variable -> fixed lower and upper loop bound
// This map is expected to only ever contain a singular value.
// However, we store it in a map in order to match the definition of
// `outer_induction_variables` as both maps share checks for evaluating binary operations.
current_induction_variables: HashMap<ValueId, FieldElement>,
// Maps outer loop induction variable -> fixed upper loop bound
current_induction_variables: HashMap<ValueId, (FieldElement, FieldElement)>,
// Maps outer loop induction variable -> fixed lower and upper loop bound
// This will be used by inner loops to determine whether they
// have safe operations reliant upon an outer loop's maximum induction variable.
outer_induction_variables: HashMap<ValueId, FieldElement>,
outer_induction_variables: HashMap<ValueId, (FieldElement, FieldElement)>,
// This context struct processes runs across all loops.
// This stores the current loop's pre-header block.
// It is wrapped in an Option as our SSA `Id<T>` does not allow dummy values.
current_pre_header: Option<BasicBlockId>,
}

impl<'f> LoopInvariantContext<'f> {
Expand All @@ -108,20 +113,25 @@ impl<'f> LoopInvariantContext<'f> {
loop_invariants: HashSet::default(),
current_induction_variables: HashMap::default(),
outer_induction_variables: HashMap::default(),
current_pre_header: None,
}
}

fn hoist_loop_invariants(&mut self, loop_: &Loop, pre_header: BasicBlockId) {
fn pre_header(&self) -> BasicBlockId {
self.current_pre_header.expect("ICE: Pre-header block should have been set")
}

fn hoist_loop_invariants(&mut self, loop_: &Loop) {
self.set_values_defined_in_loop(loop_);

for block in loop_.blocks.iter() {
for instruction_id in self.inserter.function.dfg[*block].take_instructions() {
self.transform_to_unchecked_from_upper_bound(instruction_id);
self.transform_to_unchecked_from_loop_bounds(instruction_id);

let hoist_invariant = self.can_hoist_invariant(instruction_id);

if hoist_invariant {
self.inserter.push_instruction(instruction_id, pre_header);
self.inserter.push_instruction(instruction_id, self.pre_header());

// If we are hoisting a MakeArray instruction,
// we need to issue an extra inc_rc in case they are mutated afterward.
Expand Down Expand Up @@ -219,7 +229,7 @@ impl<'f> LoopInvariantContext<'f> {
let can_be_deduplicated = instruction.can_be_deduplicated(self.inserter.function, false)
|| matches!(instruction, Instruction::MakeArray { .. })
|| matches!(instruction, Instruction::Binary(_))
|| self.can_be_deduplicated_from_upper_bound(&instruction);
|| self.can_be_deduplicated_from_loop_bound(&instruction);

is_loop_invariant && can_be_deduplicated
}
Expand All @@ -230,31 +240,33 @@ impl<'f> LoopInvariantContext<'f> {
/// When within the current loop, the known upper bound can be used to simplify instructions,
/// such as transforming a checked add to an unchecked add.
fn set_induction_var_bounds(&mut self, loop_: &Loop, current_loop: bool) {
let upper_bound = loop_.get_const_upper_bound(self.inserter.function);
if let Some(upper_bound) = upper_bound {
let bounds = loop_.get_const_bounds(self.inserter.function, self.pre_header());
if let Some((lower_bound, upper_bound)) = bounds {
let induction_variable = loop_.get_induction_variable(self.inserter.function);
let induction_variable = self.inserter.resolve(induction_variable);
if current_loop {
self.current_induction_variables.insert(induction_variable, upper_bound);
self.current_induction_variables
.insert(induction_variable, (lower_bound, upper_bound));
} else {
self.outer_induction_variables.insert(induction_variable, upper_bound);
self.outer_induction_variables
.insert(induction_variable, (lower_bound, upper_bound));
}
}
}

/// Certain instructions can take advantage of that our induction variable has a fixed maximum.
/// Certain instructions can take advantage of that our induction variable has a fixed minimum/maximum.
///
/// For example, an array access can usually only be safely deduplicated when we have a constant
/// index that is below the length of the array.
/// Checking an array get where the index is the loop's induction variable on its own
/// would determine that the instruction is not safe for hoisting.
/// However, if we know that the induction variable's upper bound will always be in bounds of the array
/// we can safely hoist the array access.
fn can_be_deduplicated_from_upper_bound(&self, instruction: &Instruction) -> bool {
fn can_be_deduplicated_from_loop_bound(&self, instruction: &Instruction) -> bool {
match instruction {
Instruction::ArrayGet { array, index } => {
let array_typ = self.inserter.function.dfg.type_of_value(*array);
let upper_bound = self.outer_induction_variables.get(index);
let upper_bound = self.outer_induction_variables.get(index).map(|bounds| bounds.1);
if let (Type::Array(_, len), Some(upper_bound)) = (array_typ, upper_bound) {
upper_bound.to_u128() <= len.into()
} else {
Expand All @@ -268,15 +280,15 @@ impl<'f> LoopInvariantContext<'f> {
}
}

/// Binary operations can take advantage of that our induction variable has a fixed maximum,
/// Binary operations can take advantage of that our induction variable has a fixed minimum/maximum,
/// to be transformed from a checked operation to an unchecked operation.
///
/// Checked operations require more bytecode and thus we aim to minimize their usage wherever possible.
///
/// If one side of a binary operation is a constant and the other is an induction variable
/// For example, if one side of an add/mul operation is a constant and the other is an induction variable
/// with a known upper bound, we know whether that binary operation will ever overflow.
/// If we determine that an overflow is not possible we can convert the checked operation to unchecked.
fn transform_to_unchecked_from_upper_bound(&mut self, instruction_id: InstructionId) {
fn transform_to_unchecked_from_loop_bounds(&mut self, instruction_id: InstructionId) {
let Instruction::Binary(binary) = &self.inserter.function.dfg[instruction_id] else {
return;
};
Expand All @@ -292,16 +304,19 @@ impl<'f> LoopInvariantContext<'f> {
};
}

/// Checks whether a binary operation can be evaluated using the upper bound of the given loop induction variables.
/// Checks whether a binary operation can be evaluated using the bounds of a given loop induction variables.
///
/// If it cannot be evaluated, it means that we either have a dynamic loop bound or
/// that the operation can potentially overflow at the upper loop bound.
/// that the operation can potentially overflow during a given loop iteration.
fn can_evaluate_binary_op(
&self,
binary: &Binary,
induction_vars: &HashMap<ValueId, FieldElement>,
induction_vars: &HashMap<ValueId, (FieldElement, FieldElement)>,
) -> bool {
if !matches!(binary.operator, BinaryOp::Add { .. } | BinaryOp::Mul { .. }) {
if !matches!(
binary.operator,
BinaryOp::Add { .. } | BinaryOp::Mul { .. } | BinaryOp::Sub { .. }
) {
return false;
}

Expand All @@ -315,8 +330,16 @@ impl<'f> LoopInvariantContext<'f> {
induction_vars.get(&binary.lhs),
induction_vars.get(&binary.rhs),
) {
(Some((lhs, _)), None, None, Some(upper_bound)) => (lhs, *upper_bound),
(None, Some((rhs, _)), Some(upper_bound), None) => (*upper_bound, rhs),
(Some((lhs, _)), None, None, Some((_, upper_bound))) => (lhs, *upper_bound),
(None, Some((rhs, _)), Some((lower_bound, upper_bound)), None) => {
if matches!(binary.operator, BinaryOp::Sub { .. }) {
// If we are subtracting and the induction variable is on the lhs,
// we want to check the induction variable lower bound.
(*lower_bound, rhs)
} else {
(*upper_bound, rhs)
}
}
_ => return false,
};

Expand Down Expand Up @@ -804,4 +827,77 @@ mod test {
let ssa = ssa.loop_invariant_code_motion();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn do_not_transform_unsafe_sub_to_unchecked() {
// This test is identical to `simple_loop_invariant_code_motion`, except this test
// uses a checked sub in `b3`.
// We want to make sure that our sub operation has the induction variable (`v2`) on the lhs.
// The induction variable `v2` is placed on the lhs of the sub operation
// to test that we are checking against the loop's lower bound
// rather than the upper bound (add/mul only check against the upper bound).
let src = "
brillig(inline) fn main f0 {
b0(v0: u32, v1: u32):
jmp b1(u32 0)
b1(v2: u32):
v5 = lt v2, u32 4
jmpif v5 then: b3, else: b2
b2():
return
b3():
v7 = sub v2, u32 1
jmp b1(v7)
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.loop_invariant_code_motion();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn transform_safe_sub_to_unchecked() {
// This test is identical to `do_not_transform_unsafe_sub_to_unchecked`, except the loop
// in this test starts with a lower bound of `1`.
let src = "
brillig(inline) fn main f0 {
b0(v0: u32, v1: u32):
jmp b1(u32 1)
b1(v2: u32):
v5 = lt v2, u32 4
jmpif v5 then: b3, else: b2
b2():
return
b3():
v6 = mul v0, v1
constrain v6 == u32 6
v8 = sub v2, u32 1
jmp b1(v8)
}
";

let ssa = Ssa::from_str(src).unwrap();

// `v8 = sub v2, u32 1` in b3 should now be `v9 = unchecked_sub v2, u32 1` in b3
let expected = "
brillig(inline) fn main f0 {
b0(v0: u32, v1: u32):
v3 = mul v0, v1
jmp b1(u32 1)
b1(v2: u32):
v6 = lt v2, u32 4
jmpif v6 then: b3, else: b2
b2():
return
b3():
constrain v3 == u32 6
v8 = unchecked_sub v2, u32 1
jmp b1(v8)
}
";

let ssa = ssa.loop_invariant_code_motion();
assert_normalized_ssa_equals(ssa, expected);
}
}
22 changes: 12 additions & 10 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,8 @@ impl Loop {
fn get_const_lower_bound(
&self,
function: &Function,
cfg: &ControlFlowGraph,
pre_header: BasicBlockId,
) -> Option<FieldElement> {
let pre_header = self.get_pre_header(function, cfg).ok()?;
let jump_value = get_induction_variable(function, pre_header).ok()?;
function.dfg.get_numeric_constant(jump_value)
}
Expand All @@ -320,7 +319,7 @@ impl Loop {
/// v5 = lt v1, u32 4 // Upper bound
/// jmpif v5 then: b3, else: b2
/// ```
pub(super) fn get_const_upper_bound(&self, function: &Function) -> Option<FieldElement> {
fn get_const_upper_bound(&self, function: &Function) -> Option<FieldElement> {
let block = &function.dfg[self.header];
let instructions = block.instructions();
if instructions.is_empty() {
Expand Down Expand Up @@ -351,12 +350,12 @@ impl Loop {
}

/// Get the lower and upper bounds of the loop if both are constant numeric values.
fn get_const_bounds(
pub(super) fn get_const_bounds(
&self,
function: &Function,
cfg: &ControlFlowGraph,
pre_header: BasicBlockId,
) -> Option<(FieldElement, FieldElement)> {
let lower = self.get_const_lower_bound(function, cfg)?;
let lower = self.get_const_lower_bound(function, pre_header)?;
let upper = self.get_const_upper_bound(function)?;
Some((lower, upper))
}
Expand Down Expand Up @@ -665,7 +664,8 @@ impl Loop {
function: &Function,
cfg: &ControlFlowGraph,
) -> Option<BoilerplateStats> {
let (lower, upper) = self.get_const_bounds(function, cfg)?;
let pre_header = self.get_pre_header(function, cfg).ok()?;
let (lower, upper) = self.get_const_bounds(function, pre_header)?;
let lower = lower.try_to_u64()?;
let upper = upper.try_to_u64()?;
let refs = self.find_pre_header_reference_values(function, cfg)?;
Expand Down Expand Up @@ -1143,9 +1143,11 @@ mod tests {
let loops = Loops::find_all(function);
assert_eq!(loops.yet_to_unroll.len(), 1);

let (lower, upper) = loops.yet_to_unroll[0]
.get_const_bounds(function, &loops.cfg)
.expect("bounds are numeric const");
let loop_ = &loops.yet_to_unroll[0];
let pre_header =
loop_.get_pre_header(function, &loops.cfg).expect("Should have a pre_header");
let (lower, upper) =
loop_.get_const_bounds(function, pre_header).expect("bounds are numeric const");

assert_eq!(lower, FieldElement::from(0u32));
assert_eq!(upper, FieldElement::from(4u32));
Expand Down

0 comments on commit 7cdce1f

Please sign in to comment.