Skip to content

Commit

Permalink
feat: Implement hint on uint256_mul_div_mod (lambdaclass#957)
Browse files Browse the repository at this point in the history
* Add normalize address hints

* Revert "Add normalize address hints"

This reverts commit 86077f2.

* Add uint256_mul_div_mod hint

* Add hint to match

* Expand uint256 integration test

* use u128::MAX

* Fix value

* Manage quotient & remainder as BigUint

* Add test for hint

* Add test for hint

* Add misc test

* Add changelog entry

* fmt

* fmt
  • Loading branch information
fmoletta authored and kariy committed Jun 23, 2023
1 parent 68585b3 commit abadc9d
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 2 deletions.
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,26 @@

#### Upcoming Changes

* Implement hint on `uint256_mul_div_mod`[#957](https://github.com/lambdaclass/cairo-rs/pull/957)

`BuiltinHintProcessor` now supports the following hint:

```python
a = (ids.a.high << 128) + ids.a.low
b = (ids.b.high << 128) + ids.b.low
div = (ids.div.high << 128) + ids.div.low
quotient, remainder = divmod(a * b, div)

ids.quotient_low.low = quotient & ((1 << 128) - 1)
ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
ids.quotient_high.high = quotient >> 384
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128"
```

Used by the common library function `uint256_mul_div_mod`

* Move `Memory` into `MemorySegmentManager` [#830](https://github.com/lambdaclass/cairo-rs/pull/830)
* Structural changes:
* Remove `memory: Memory` field from `VirtualMachine`
Expand Down
19 changes: 19 additions & 0 deletions cairo_programs/uint256.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from starkware.cairo.common.uint256 import (
uint256_signed_nn,
uint256_unsigned_div_rem,
uint256_mul,
uint256_mul_div_mod
)
from starkware.cairo.common.alloc import alloc

Expand Down Expand Up @@ -57,6 +58,24 @@ func main{range_check_ptr: felt}() {
assert b_quotient = Uint256(1, 0);
assert b_remainder = Uint256(340282366920938463463374607431768211377, 0);

let (a_quotient_low, a_quotient_high, a_remainder) = uint256_mul_div_mod(
Uint256(89, 72),
Uint256(3, 7),
Uint256(107, 114),
);
assert a_quotient_low = Uint256(143276786071974089879315624181797141668, 4);
assert a_quotient_high = Uint256(0, 0);
assert a_remainder = Uint256(322372768661941702228460154409043568767, 101);

let (b_quotient_low, b_quotient_high, b_remainder) = uint256_mul_div_mod(
Uint256(-3618502788666131213697322783095070105282824848410658236509717448704103809099, 2),
Uint256(1, 1),
Uint256(5, 2),
);
assert b_quotient_low = Uint256(170141183460469231731687303715884105688, 1);
assert b_quotient_high = Uint256(0, 0);
assert b_remainder = Uint256(170141183460469231731687303715884105854, 1);

let (mult_low_a, mult_high_a) = uint256_mul(Uint256(59, 2), Uint256(10, 0));
assert mult_low_a = Uint256(590, 20);
assert mult_high_a = Uint256(0, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ use crate::{
squash_dict_inner_used_accesses_assert,
},
uint256_utils::{
split_64, uint256_add, uint256_signed_nn, uint256_sqrt, uint256_unsigned_div_rem,
split_64, uint256_add, uint256_mul_div_mod, uint256_signed_nn, uint256_sqrt,
uint256_unsigned_div_rem,
},
usort::{
usort_body, usort_enter_scope, verify_multiplicity_assert,
Expand Down Expand Up @@ -452,6 +453,9 @@ impl HintProcessor for BuiltinHintProcessor {
chained_ec_op_random_ec_point_hint(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::RECOVER_Y => recover_y_hint(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT256_MUL_DIV_MOD => {
uint256_mul_div_mod(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
#[cfg(feature = "skip_next_instruction_hint")]
hint_code::SKIP_NEXT_INSTRUCTION => skip_next_instruction(vm),
code => Err(HintError::UnknownHint(code.to_string())),
Expand Down
12 changes: 12 additions & 0 deletions src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,18 @@ ids.quotient.high = quotient >> 128
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128"#;

pub(crate) const UINT256_MUL_DIV_MOD: &str = r#"a = (ids.a.high << 128) + ids.a.low
b = (ids.b.high << 128) + ids.b.low
div = (ids.div.high << 128) + ids.div.low
quotient, remainder = divmod(a * b, div)
ids.quotient_low.low = quotient & ((1 << 128) - 1)
ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
ids.quotient_high.high = quotient >> 384
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128"#;

pub(crate) const USORT_ENTER_SCOPE: &str =
"vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))";
pub(crate) const USORT_BODY: &str = r#"from collections import defaultdict
Expand Down
197 changes: 196 additions & 1 deletion src/hint_processor/builtin_hint_processor/uint256_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use crate::{
vm::{errors::hint_errors::HintError, vm_core::VirtualMachine},
};
use felt::Felt252;
use num_integer::div_rem;
use num_bigint::BigUint;
use num_integer::{div_rem, Integer};
use num_traits::{One, Signed, Zero};
/*
Implements hint:
Expand Down Expand Up @@ -217,9 +218,90 @@ pub fn uint256_unsigned_div_rem(
Ok(())
}

/* Implements Hint:
%{
a = (ids.a.high << 128) + ids.a.low
b = (ids.b.high << 128) + ids.b.low
div = (ids.div.high << 128) + ids.div.low
quotient, remainder = divmod(a * b, div)
ids.quotient_low.low = quotient & ((1 << 128) - 1)
ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
ids.quotient_high.high = quotient >> 384
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128
%}
*/
pub fn uint256_mul_div_mod(
vm: &mut VirtualMachine,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
) -> Result<(), HintError> {
// Extract variables
let a_addr = get_relocatable_from_var_name("a", vm, ids_data, ap_tracking)?;
let b_addr = get_relocatable_from_var_name("b", vm, ids_data, ap_tracking)?;
let div_addr = get_relocatable_from_var_name("div", vm, ids_data, ap_tracking)?;
let quotient_low_addr =
get_relocatable_from_var_name("quotient_low", vm, ids_data, ap_tracking)?;
let quotient_high_addr =
get_relocatable_from_var_name("quotient_high", vm, ids_data, ap_tracking)?;
let remainder_addr = get_relocatable_from_var_name("remainder", vm, ids_data, ap_tracking)?;

let a_low = vm.get_integer(a_addr)?;
let a_high = vm.get_integer((a_addr + 1_usize)?)?;
let b_low = vm.get_integer(b_addr)?;
let b_high = vm.get_integer((b_addr + 1_usize)?)?;
let div_low = vm.get_integer(div_addr)?;
let div_high = vm.get_integer((div_addr + 1_usize)?)?;
let a_low = a_low.as_ref();
let a_high = a_high.as_ref();
let b_low = b_low.as_ref();
let b_high = b_high.as_ref();
let div_low = div_low.as_ref();
let div_high = div_high.as_ref();

// Main Logic
let a = a_high.shl(128_usize) + a_low;
let b = b_high.shl(128_usize) + b_low;
let div = div_high.shl(128_usize) + div_low;
let (quotient, remainder) = (a.to_biguint() * b.to_biguint()).div_mod_floor(&div.to_biguint());

// ids.quotient_low.low
vm.insert_value(
quotient_low_addr,
Felt252::from(&quotient & &BigUint::from(u128::MAX)),
)?;
// ids.quotient_low.high
vm.insert_value(
(quotient_low_addr + 1)?,
Felt252::from((&quotient).shr(128_u32) & &BigUint::from(u128::MAX)),
)?;
// ids.quotient_high.low
vm.insert_value(
quotient_high_addr,
Felt252::from((&quotient).shr(256_u32) & &BigUint::from(u128::MAX)),
)?;
// ids.quotient_high.high
vm.insert_value(
(quotient_high_addr + 1)?,
Felt252::from((&quotient).shr(384_u32)),
)?;
//ids.remainder.low
vm.insert_value(
remainder_addr,
Felt252::from(&remainder & &BigUint::from(u128::MAX)),
)?;
//ids.remainder.high
vm.insert_value((remainder_addr + 1)?, Felt252::from(remainder.shr(128_u32)))?;

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::hint_processor::builtin_hint_processor::hint_code;
use crate::vm::vm_memory::memory_segments::MemorySegmentManager;
use crate::{
any_box,
Expand Down Expand Up @@ -573,4 +655,117 @@ mod tests {
z == MaybeRelocatable::from(Felt252::new(10))
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_unsigned_div_rem_invalid_memory_insert_2() {
let hint_code = "a = (ids.a.high << 128) + ids.a.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a, div)\n\nids.quotient.low = quotient & ((1 << 128) - 1)\nids.quotient.high = quotient >> 128\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128";
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data =
non_continuous_ids_data![("a", -6), ("div", -4), ("quotient", 0), ("remainder", 2)];
//Insert ids into memory
vm.segments = segments![
((1, 4), 89),
((1, 5), 72),
((1, 6), 3),
((1, 7), 7),
((1, 11), 1)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code),
Err(HintError::Memory(
MemoryError::InconsistentMemory(
x,
y,
z,
)
)) if x == Relocatable::from((1, 11)) &&
y == MaybeRelocatable::from(Felt252::one()) &&
z == MaybeRelocatable::from(Felt252::zero())
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_mul_div_mod_ok() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data = non_continuous_ids_data![
("a", -8),
("b", -6),
("div", -4),
("quotient_low", 0),
("quotient_high", 2),
("remainder", 4)
];
//Insert ids into memory
vm.segments = segments![
((1, 2), 89),
((1, 3), 72),
((1, 4), 3),
((1, 5), 7),
((1, 6), 107),
((1, 7), 114)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code::UINT256_MUL_DIV_MOD),
Ok(())
);
//Check hint memory inserts
//ids.quotient.low, ids.quotient.high, ids.remainder.low, ids.remainder.high
check_memory![
vm.segments.memory,
((1, 10), 143276786071974089879315624181797141668),
((1, 11), 4),
((1, 12), 0),
((1, 13), 0),
//((1, 14), 322372768661941702228460154409043568767),
((1, 15), 101)
];
assert_eq!(
vm.segments
.memory
.get_integer((1, 14).into())
.unwrap()
.as_ref(),
&felt_str!("322372768661941702228460154409043568767")
)
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_mul_div_mod_missing_ids() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data = non_continuous_ids_data![
("a", -8),
("b", -6),
("div", -4),
("quotient", 0),
("remainder", 2)
];
//Insert ids into memory
vm.segments = segments![
((1, 2), 89),
((1, 3), 72),
((1, 4), 3),
((1, 5), 7),
((1, 6), 107),
((1, 7), 114)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code::UINT256_MUL_DIV_MOD),
Err(HintError::UnknownIdentifier(s)) if s == "quotient_low"
);
}
}

0 comments on commit abadc9d

Please sign in to comment.