Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster constant-time division #643

Merged
merged 9 commits into from
Aug 18, 2024
8 changes: 8 additions & 0 deletions benches/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,14 @@ fn bench_sqrt(c: &mut Criterion) {
BatchSize::SmallInput,
)
});

group.bench_function("sqrt_vartime, U256", |b| {
b.iter_batched(
|| U256::random(&mut OsRng),
|x| x.sqrt_vartime(),
BatchSize::SmallInput,
)
});
}

criterion_group!(
Expand Down
6 changes: 0 additions & 6 deletions src/limb/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@ impl Limb {
pub const fn shl(self, shift: u32) -> Self {
Limb(self.0 << shift)
}

/// Computes `self << 1` and return the result and the carry (0 or 1).
#[inline(always)]
pub(crate) const fn shl1(self) -> (Self, Self) {
(Self(self.0 << 1), Self(self.0 >> Self::HI_BIT))
}
}

macro_rules! impl_shl {
Expand Down
11 changes: 10 additions & 1 deletion src/non_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use serdect::serde::{
};

/// Wrapper type for non-zero integers.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct NonZero<T>(pub(crate) T);

Expand Down Expand Up @@ -210,6 +210,15 @@ where
}
}

impl<T> Default for NonZero<T>
tarcieri marked this conversation as resolved.
Show resolved Hide resolved
where
T: Constants,
{
fn default() -> Self {
Self(T::ONE)
}
}

impl<T> Deref for NonZero<T> {
type Target = T;

Expand Down
134 changes: 103 additions & 31 deletions src/uint/boxed/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
use crate::{
uint::{boxed, div_limb::div3by2},
BoxedUint, CheckedDiv, ConstChoice, ConstantTimeSelect, DivRemLimb, Limb, NonZero, Reciprocal,
RemLimb, Wrapping,
RemLimb, Word, Wrapping,
};
use core::ops::{Div, DivAssign, Rem, RemAssign};
use subtle::{Choice, ConstantTimeLess, CtOption};
use subtle::CtOption;

impl BoxedUint {
/// Computes `self / rhs` using a pre-made reciprocal,
Expand Down Expand Up @@ -118,41 +118,113 @@ impl BoxedUint {
/// Perform checked division, returning a [`CtOption`] which `is_some`
/// only if the rhs != 0
pub fn checked_div(&self, rhs: &Self) -> CtOption<Self> {
let q = self.div_rem_unchecked(rhs).0;
CtOption::new(q, !rhs.is_zero())
let is_nz = rhs.is_nonzero();
let nz = NonZero(Self::ct_select(
&Self::one_with_precision(self.bits_precision()),
rhs,
is_nz,
));
let q = self.div_rem_unchecked(&nz).0;
CtOption::new(q, is_nz)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r) without checking if `rhs`
/// is zero.
///
/// This function is constant-time with respect to both `self` and `rhs`.
fn div_rem_unchecked(&self, rhs: &Self) -> (Self, Self) {
debug_assert_eq!(self.bits_precision(), rhs.bits_precision());
let mb = rhs.bits();
let bits_precision = self.bits_precision();
let mut rem = self.clone();
let mut quo = Self::zero_with_precision(bits_precision);
let (mut c, _overflow) = rhs.overflowing_shl(bits_precision - mb);
let mut i = bits_precision;
let mut done = Choice::from(0u8);

loop {
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
rem.ct_assign(&r, !(Choice::from((borrow.0 & 1) as u8) | done));
r = quo.bitor(&Self::one());
quo.ct_assign(&r, !(Choice::from((borrow.0 & 1) as u8) | done));
if i == 0 {
break;
// Based on Section 4.3.1, of The Art of Computer Programming, Volume 2, by Donald E. Knuth.
// Further explanation at https://janmr.com/blog/2014/04/basic-multiple-precision-long-division/

let size = self.limbs.len();
assert_eq!(
size,
rhs.limbs.len(),
"the precision of the divisor must match the dividend"
);

// Short circuit for single-word precision
if size == 1 {
let (quo, rem_limb) = self.div_rem_limb(rhs.limbs[0].to_nz().expect("zero divisor"));
let mut rem = Self::zero_with_precision(self.bits_precision());
rem.limbs[0] = rem_limb;
return (quo, rem);
}

let dbits = rhs.bits();
assert!(dbits > 0, "zero divisor");
let dwords = (dbits + Limb::BITS - 1) / Limb::BITS;
let lshift = (Limb::BITS - (dbits % Limb::BITS)) % Limb::BITS;

// Shift entire divisor such that the high bit is set
let mut y = rhs.shl((size as u32) * Limb::BITS - dbits).to_limbs();
tarcieri marked this conversation as resolved.
Show resolved Hide resolved
// Shift the dividend to align the words
let (x, mut x_hi) = self.shl_limb(lshift);
let mut x = x.to_limbs();
let mut xi = size - 1;
let mut x_lo = x[size - 1];
let mut i;
let mut carry;

let reciprocal = Reciprocal::new(y[size - 1].to_nz().expect("zero divisor"));

while xi > 0 {
// Divide high dividend words by the high divisor word to estimate the quotient word
let (mut quo, _) = div3by2(x_hi.0, x_lo.0, x[xi - 1].0, &reciprocal, y[size - 2].0);

// This loop is a no-op once xi is smaller than the number of words in the divisor
let done = ConstChoice::from_u32_lt(xi as u32, dwords - 1);
quo = done.select_word(quo, 0);

// Subtract q*divisor from the dividend
carry = Limb::ZERO;
let mut borrow = Limb::ZERO;
let mut tmp;
i = 0;
while i <= xi {
(tmp, carry) = Limb::ZERO.mac(y[size - xi + i - 1], Limb(quo), carry);
(x[i], borrow) = x[i].sbb(tmp, borrow);
i += 1;
}
i -= 1;
// when `i < mb`, the computation is actually done, so we ensure `quo` and `rem`
// aren't modified further (but do the remaining iterations anyway to be constant-time)
done = i.ct_lt(&mb);
c.shr1_assign();
quo.ct_assign(&quo.shl1(), !done);
(_, borrow) = x_hi.sbb(carry, borrow);

// If the subtraction borrowed, then decrement q and add back the divisor
// The probability of this being needed is very low, about 2/(Limb::MAX+1)
let ct_borrow = ConstChoice::from_word_mask(borrow.0);
carry = Limb::ZERO;
i = 0;
while i <= xi {
(x[i], carry) = x[i].adc(
Limb::select(Limb::ZERO, y[size - xi + i - 1], ct_borrow),
carry,
);
i += 1;
}
quo = ct_borrow.select_word(quo, quo.saturating_sub(1));

// Store the quotient within dividend and set x_hi to the current highest word
x_hi = Limb::select(x[xi], x_hi, done);
x[xi] = Limb::select(Limb(quo), x[xi], done);
x_lo = Limb::select(x[xi - 1], x_lo, done);
xi -= 1;
}

let limb_div = ConstChoice::from_u32_eq(1, dwords);
// Calculate quotient and remainder for the case where the divisor is a single word
let (quo2, rem2) = div3by2(x_hi.0, x_lo.0, 0, &reciprocal, 0);

// Adjust the quotient for single limb division
x[0] = Limb::select(x[0], Limb(quo2), limb_div);

// Copy out the remainder
y[0] = Limb::select(x[0], Limb(rem2 as Word), limb_div);
i = 1;
while i < size {
y[i] = Limb::select(Limb::ZERO, x[i], ConstChoice::from_u32_lt(i as u32, dwords));
y[i] = Limb::select(y[i], x_hi, ConstChoice::from_u32_eq(i as u32, dwords - 1));
i += 1;
}

(quo, rem)
(
Self { limbs: x }.shr((dwords - 1) * Limb::BITS),
Self { limbs: y }.shr(lshift),
)
}
}

Expand Down
27 changes: 0 additions & 27 deletions src/uint/boxed/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,6 @@ impl BoxedUint {
success.map(|_| result)
}

/// Computes `self << 1` in constant-time.
pub(crate) fn shl1(&self) -> Self {
let mut ret = self.clone();
ret.shl1_assign();
ret
}

/// Computes `self << 1` in-place in constant-time.
pub(crate) fn shl1_assign(&mut self) {
let mut carry = self.limbs[0].0 >> Limb::HI_BIT;
self.limbs[0].shl_assign(1);
for i in 1..self.limbs.len() {
let new_carry = self.limbs[i].0 >> Limb::HI_BIT;
self.limbs[i].shl_assign(1);
self.limbs[i].0 |= carry;
carry = new_carry
}
}

/// Computes `self << shift` where `0 <= shift < Limb::BITS`,
/// returning the result and the carry.
pub(crate) fn shl_limb(&self, shift: u32) -> (Self, Limb) {
Expand Down Expand Up @@ -230,14 +211,6 @@ impl ShlVartime for BoxedUint {
mod tests {
use super::BoxedUint;

#[test]
fn shl1_assign() {
let mut n = BoxedUint::from(0x3c442b21f19185fe433f0a65af902b8fu128);
let n_shl1 = BoxedUint::from(0x78885643e3230bfc867e14cb5f20571eu128);
n.shl1_assign();
assert_eq!(n, n_shl1);
}

#[test]
fn shl() {
let one = BoxedUint::one_with_precision(128);
Expand Down
22 changes: 10 additions & 12 deletions src/uint/boxed/sqrt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! [`BoxedUint`] square root operations.

use subtle::{ConstantTimeEq, ConstantTimeGreater, CtOption};
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption};

use crate::{BitOps, BoxedUint, ConstantTimeSelect, NonZero, SquareRoot};

Expand All @@ -23,24 +23,22 @@ impl BoxedUint {
// Repeat enough times to guarantee result has stabilized.
let mut i = 0;
let mut x_prev = x.clone(); // keep the previous iteration in case we need to roll back.
let mut nz_x = NonZero(x.clone());

// TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
while i < self.log2_bits() + 2 {
x_prev.limbs.clone_from_slice(&x.limbs);

// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`

let (nz_x, is_nonzero) = (NonZero(x.clone()), x.is_nonzero());
let x_nonzero = x.is_nonzero();
let mut j = 0;
while j < nz_x.0.limbs.len() {
nz_x.0.limbs[j].conditional_assign(&x.limbs[j], x_nonzero);
j += 1;
}
let (q, _) = self.div_rem(&nz_x);

// A protection in case `self == 0`, which will make `x == 0`
let q = Self::ct_select(
&Self::zero_with_precision(self.bits_precision()),
&q,
is_nonzero,
);

x = x.wrapping_add(&q).shr1();
x.conditional_adc_assign(&q, x_nonzero);
x.shr1_assign();
i += 1;
}

Expand Down
Loading