Skip to content

Commit

Permalink
Bring the overflow behavior in bit shifts in sync with std (#395)
Browse files Browse the repository at this point in the history
- `const fn` bit shifts for `Uint` return the overflow status as `CtChoice` (and set the result to zero in that case, which is documented, so it's a part of the API now). `Option` would be better for the vartime shifts, but its methods are not `const` yet in stable.
- `shl/shr` for `BoxedUint` return `(Self, Choice)` (not `CtOption` since most of its methods need the type to be `ConditionallySelectable`, which `BoxedUint` isn't). The vartime equivalents return `Option<Self>`.
- operator impls panic on overflow (which is the default behavior for built-in integers)
- made the implementations in `uint/shl.rs` and `shr.rs` more uniform and improved vartime shift performance (before it was calling a constant-time shift-by-no-more-than-limb which added some overhead)  
- improved constant-time shift performance for `BoxedUint` by reducing the amount of allocations
- added an optimized `BoxedUint::shl1()` implementation
- added some inlines for `Limb` methods which improved shift performance noticeably
- added more benchmarks for shifts and simplify benchmark hierarchy a little (create test group directly in the respective function)
- fixed an inefficiency in `Uint` shifts: we need to iterate to log2(BITS-1), not log2(BITS), because that's the maximum size of the shift.
- Renamed `sh(r/l)1_with_overflow()` to `sh(r/l)1_with_carry` to avoid confusion - in the context of shifts we call the shift being too large an overflow.
  • Loading branch information
fjarri authored Dec 13, 2023
1 parent 7359ebc commit 55312b6
Show file tree
Hide file tree
Showing 26 changed files with 641 additions and 251 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ name = "boxed_residue"
harness = false
required-features = ["alloc"]

[[bench]]
name = "boxed_uint"
harness = false
required-features = ["alloc"]

[[bench]]
name = "dyn_residue"
harness = false
Expand Down
48 changes: 48 additions & 0 deletions benches/boxed_uint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::BoxedUint;
use rand_core::OsRng;

/// Size of `BoxedUint` to use in benchmark.
const UINT_BITS: u32 = 4096;

fn bench_shifts(c: &mut Criterion) {
let mut group = c.benchmark_group("bit shifts");

group.bench_function("shl_vartime", |b| {
b.iter_batched(
|| BoxedUint::random(&mut OsRng, UINT_BITS),
|x| black_box(x.shl_vartime(UINT_BITS / 2 + 10)),
BatchSize::SmallInput,
)
});

group.bench_function("shl", |b| {
b.iter_batched(
|| BoxedUint::random(&mut OsRng, UINT_BITS),
|x| x.shl(UINT_BITS / 2 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr_vartime", |b| {
b.iter_batched(
|| BoxedUint::random(&mut OsRng, UINT_BITS),
|x| black_box(x.shr_vartime(UINT_BITS / 2 + 10)),
BatchSize::SmallInput,
)
});

group.bench_function("shr", |b| {
b.iter_batched(
|| BoxedUint::random(&mut OsRng, UINT_BITS),
|x| x.shr(UINT_BITS / 2 + 10),
BatchSize::SmallInput,
)
});

group.finish();
}

criterion_group!(benches, bench_shifts);

criterion_main!(benches);
70 changes: 50 additions & 20 deletions benches/uint.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use criterion::{
black_box, criterion_group, criterion_main, measurement::Measurement, BatchSize,
BenchmarkGroup, Criterion,
};
use crypto_bigint::{Limb, NonZero, Random, Reciprocal, U128, U2048, U256};
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::{Limb, NonZero, Random, Reciprocal, Uint, U128, U2048, U256};
use rand_core::OsRng;

fn bench_division<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
fn bench_division(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");

group.bench_function("div/rem, U256/U128, full size", |b| {
b.iter_batched(
|| {
Expand Down Expand Up @@ -69,9 +68,13 @@ fn bench_division<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
BatchSize::SmallInput,
)
});

group.finish();
}

fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
fn bench_shl(c: &mut Criterion) {
let mut group = c.benchmark_group("left shift");

group.bench_function("shl_vartime, small, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput)
});
Expand All @@ -84,16 +87,54 @@ fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
)
});

group.bench_function("shl_vartime_wide, large, U2048", |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::shl_vartime_wide(x, 1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shl, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput)
});

group.finish();
}

fn bench_shr(c: &mut Criterion) {
let mut group = c.benchmark_group("right shift");

group.bench_function("shr_vartime, small, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shr_vartime(10), BatchSize::SmallInput)
});

group.bench_function("shr_vartime, large, U2048", |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.shr_vartime(1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr_vartime_wide, large, U2048", |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::shr_vartime_wide(x, 1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput)
});

group.finish();
}

fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
fn bench_inv_mod(c: &mut Criterion) {
let mut group = c.benchmark_group("modular ops");

group.bench_function("inv_odd_mod, U256", |b| {
b.iter_batched(
|| {
Expand Down Expand Up @@ -144,21 +185,10 @@ fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
BatchSize::SmallInput,
)
});
}

fn bench_wrapping_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
bench_division(&mut group);
group.finish();
}

fn bench_modular_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("modular ops");
bench_shifts(&mut group);
bench_inv_mod(&mut group);
group.finish();
}

criterion_group!(benches, bench_wrapping_ops, bench_modular_ops);
criterion_group!(benches, bench_shl, bench_shr, bench_division, bench_inv_mod);

criterion_main!(benches);
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
//! U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551");
//!
//! // Compute `MODULUS` shifted right by 1 at compile time
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1);
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1).0;
//! ```
//!
//! Note that large constant computations may accidentally trigger a the `const_eval_limit` of the compiler.
Expand Down
1 change: 1 addition & 0 deletions src/limb/bit_not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::Not;

impl Limb {
/// Calculates `!a`.
#[inline(always)]
pub const fn not(self) -> Self {
Limb(!self.0)
}
Expand Down
1 change: 1 addition & 0 deletions src/limb/bit_or.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::{BitOr, BitOrAssign};

impl Limb {
/// Calculates `a | b`.
#[inline(always)]
pub const fn bitor(self, rhs: Self) -> Self {
Limb(self.0 | rhs.0)
}
Expand Down
1 change: 1 addition & 0 deletions src/limb/bit_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::BitXor;

impl Limb {
/// Calculates `a ^ b`.
#[inline(always)]
pub const fn bitxor(self, rhs: Self) -> Self {
Limb(self.0 ^ rhs.0)
}
Expand Down
4 changes: 4 additions & 0 deletions src/limb/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,25 @@ use super::Limb;

impl Limb {
/// Calculate the number of bits needed to represent this number.
#[inline(always)]
pub const fn bits(self) -> u32 {
Limb::BITS - self.0.leading_zeros()
}

/// Calculate the number of leading zeros in the binary representation of this number.
#[inline(always)]
pub const fn leading_zeros(self) -> u32 {
self.0.leading_zeros()
}

/// Calculate the number of trailing zeros in the binary representation of this number.
#[inline(always)]
pub const fn trailing_zeros(self) -> u32 {
self.0.trailing_zeros()
}

/// Calculate the number of trailing ones the binary representation of this number.
#[inline(always)]
pub const fn trailing_ones(self) -> u32 {
self.0.trailing_ones()
}
Expand Down
2 changes: 1 addition & 1 deletion src/limb/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Limb {
}

/// Perform saturating multiplication.
#[inline]
#[inline(always)]
pub const fn saturating_mul(&self, rhs: Self) -> Self {
Limb(self.0.saturating_mul(rhs.0))
}
Expand Down
6 changes: 6 additions & 0 deletions src/limb/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ impl Limb {
pub const fn shl(self, shift: u32) -> Self {
Limb(self.0 << shift)
}

/// Computes `self << 1` and return the result and the carry (0 or 1).
#[inline(always)]
pub(crate) const fn shl1(self) -> (Self, Self) {
(Self(self.0 << 1), Self(self.0 >> Self::HI_BIT))
}
}

impl Shl<u32> for Limb {
Expand Down
6 changes: 6 additions & 0 deletions src/limb/shr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ impl Limb {
pub const fn shr(self, shift: u32) -> Self {
Limb(self.0 >> shift)
}

/// Computes `self >> 1` and return the result and the carry (0 or `1 << HI_BIT`).
#[inline(always)]
pub(crate) const fn shr1(self) -> (Self, Self) {
(Self(self.0 >> 1), Self(self.0 << Self::HI_BIT))
}
}

impl Shr<u32> for Limb {
Expand Down
2 changes: 1 addition & 1 deletion src/modular/div_by_2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub(crate) fn div_by_2<const LIMBS: usize>(a: &Uint<LIMBS>, modulus: &Uint<LIMBS
// ("+1" because both `a` and `modulus` are odd, we lose 0.5 in each integer division).
// This will not overflow, so we can just use wrapping operations.

let (half, is_odd) = a.shr1_with_overflow();
let (half, is_odd) = a.shr1_with_carry();
let half_modulus = modulus.shr1();

let if_even = half;
Expand Down
16 changes: 15 additions & 1 deletion src/uint/boxed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ mod rand;
use crate::{Integer, Limb, NonZero, Uint, Word, Zero, U128, U64};
use alloc::{boxed::Box, vec, vec::Vec};
use core::{fmt, mem};
use subtle::{Choice, ConstantTimeEq};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
Expand Down Expand Up @@ -253,6 +253,20 @@ impl BoxedUint {

limbs.into()
}

/// Set the value of `self` to zero in-place.
pub(crate) fn set_to_zero(&mut self) {
self.limbs.as_mut().fill(Limb::ZERO)
}

/// Set the value of `self` to zero in-place if `choice` is truthy.
pub(crate) fn conditional_set_to_zero(&mut self, choice: Choice) {
let nlimbs = self.nlimbs();
let limbs = self.limbs.as_mut();
for i in 0..nlimbs {
limbs[i] = Limb::conditional_select(&limbs[i], &Limb::ZERO, choice);
}
}
}

impl NonZero<BoxedUint> {
Expand Down
2 changes: 1 addition & 1 deletion src/uint/boxed/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ mod tests {
fn uint_with_bits_at(positions: &[u32]) -> BoxedUint {
let mut result = BoxedUint::zero_with_precision(256);
for &pos in positions {
result |= BoxedUint::one_with_precision(256).shl_vartime(pos);
result |= BoxedUint::one_with_precision(256).shl_vartime(pos).unwrap();
}
result
}
Expand Down
8 changes: 5 additions & 3 deletions src/uint/boxed/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ impl BoxedUint {
let mb = rhs.bits();
let mut bd = self.bits_precision() - mb;
let mut rem = self.clone();
let mut c = rhs.shl_vartime(bd);
// Will not overflow since `bd < bits_precision`
let mut c = rhs.shl_vartime(bd).expect("shift within range");

loop {
let (r, borrow) = rem.sbb(&c, Limb::ZERO);
Expand Down Expand Up @@ -77,7 +78,7 @@ impl BoxedUint {
let bits_precision = self.bits_precision();
let mut rem = self.clone();
let mut quo = Self::zero_with_precision(bits_precision);
let mut c = rhs.shl(bits_precision - mb);
let (mut c, _overflow) = rhs.shl(bits_precision - mb);
let mut i = bits_precision;
let mut done = Choice::from(0u8);

Expand Down Expand Up @@ -110,7 +111,8 @@ impl BoxedUint {
let mut bd = self.bits_precision() - mb;
let mut remainder = self.clone();
let mut quotient = Self::zero_with_precision(self.bits_precision());
let mut c = rhs.shl_vartime(bd);
// Will not overflow since `bd < bits_precision`
let mut c = rhs.shl_vartime(bd).expect("shift within range");

loop {
let (mut r, borrow) = remainder.sbb(&c, Limb::ZERO);
Expand Down
10 changes: 5 additions & 5 deletions src/uint/boxed/inv_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ impl BoxedUint {

// Decompose `modulus = s * 2^k` where `s` is odd
let k = modulus.trailing_zeros();
let s = modulus.shr(k);
let s = modulus >> k;

// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
Expand All @@ -26,7 +26,7 @@ impl BoxedUint {
let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists

// This part is mod 2^k
let mask = Self::one().shl(k).wrapping_sub(&Self::one());
let mask = (Self::one() << k).wrapping_sub(&Self::one());
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);

// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
Expand Down Expand Up @@ -126,9 +126,9 @@ impl BoxedUint {
let cyy = new_u.conditional_adc_assign(modulus, cy);
debug_assert!(bool::from(cy.ct_eq(&cyy)));

let (new_a, overflow) = a.shr1_with_overflow();
debug_assert!(bool::from(!modulus_is_odd | !overflow));
let (mut new_u, cy) = new_u.shr1_with_overflow();
let (new_a, carry) = a.shr1_with_carry();
debug_assert!(bool::from(!modulus_is_odd | !carry));
let (mut new_u, cy) = new_u.shr1_with_carry();
let cy = new_u.conditional_adc_assign(&m1hp, cy);
debug_assert!(bool::from(!modulus_is_odd | !cy));

Expand Down
Loading

0 comments on commit 55312b6

Please sign in to comment.