use super::monty::monty_modpow; use super::BigUint; use crate::big_digit::{self, BigDigit}; use num_integer::Integer; use num_traits::{One, Pow, ToPrimitive, Zero}; impl Pow<&BigUint> for BigUint { type Output = BigUint; #[inline] fn pow(self, exp: &BigUint) -> BigUint { if self.is_one() || exp.is_zero() { BigUint::one() } else if self.is_zero() { BigUint::zero() } else if let Some(exp) = exp.to_u64() { self.pow(exp) } else if let Some(exp) = exp.to_u128() { self.pow(exp) } else { // At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given // `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address! panic!("memory overflow") } } } impl Pow<BigUint> for BigUint { type Output = BigUint; #[inline] fn pow(self, exp: BigUint) -> BigUint { Pow::pow(self, &exp) } } impl Pow<&BigUint> for &BigUint { type Output = BigUint; #[inline] fn pow(self, exp: &BigUint) -> BigUint { if self.is_one() || exp.is_zero() { BigUint::one() } else if self.is_zero() { BigUint::zero() } else { self.clone().pow(exp) } } } impl Pow<BigUint> for &BigUint { type Output = BigUint; #[inline] fn pow(self, exp: BigUint) -> BigUint { Pow::pow(self, &exp) } } macro_rules! pow_impl { ($T:ty) => { impl Pow<$T> for BigUint { type Output = BigUint; fn pow(self, mut exp: $T) -> BigUint { if exp == 0 { return BigUint::one(); } let mut base = self; while exp & 1 == 0 { base = &base * &base; exp >>= 1; } if exp == 1 { return base; } let mut acc = base.clone(); while exp > 1 { exp >>= 1; base = &base * &base; if exp & 1 == 1 { acc *= &base; } } acc } } impl Pow<&$T> for BigUint { type Output = BigUint; #[inline] fn pow(self, exp: &$T) -> BigUint { Pow::pow(self, *exp) } } impl Pow<$T> for &BigUint { type Output = BigUint; #[inline] fn pow(self, exp: $T) -> BigUint { if exp == 0 { return BigUint::one(); } Pow::pow(self.clone(), exp) } } impl Pow<&$T> for &BigUint { type Output = BigUint; #[inline] fn pow(self, exp: &$T) -> BigUint { Pow::pow(self, *exp) } } }; } pow_impl!(u8); pow_impl!(u16); pow_impl!(u32); pow_impl!(u64); pow_impl!(usize); pow_impl!(u128); pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint { assert!( !modulus.is_zero(), "attempt to calculate with zero modulus!" ); if modulus.is_odd() { // For an odd modulus, we can use Montgomery multiplication in base 2^32. monty_modpow(x, exponent, modulus) } else { // Otherwise do basically the same as `num::pow`, but with a modulus. plain_modpow(x, &exponent.data, modulus) } } fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint { assert!( !modulus.is_zero(), "attempt to calculate with zero modulus!" ); let i = match exp_data.iter().position(|&r| r != 0) { None => return BigUint::one(), Some(i) => i, }; let mut base = base % modulus; for _ in 0..i { for _ in 0..big_digit::BITS { base = &base * &base % modulus; } } let mut r = exp_data[i]; let mut b = 0u8; while r.is_even() { base = &base * &base % modulus; r >>= 1; b += 1; } let mut exp_iter = exp_data[i + 1..].iter(); if exp_iter.len() == 0 && r.is_one() { return base; } let mut acc = base.clone(); r >>= 1; b += 1; { let mut unit = |exp_is_odd| { base = &base * &base % modulus; if exp_is_odd { acc *= &base; acc %= modulus; } }; if let Some(&last) = exp_iter.next_back() { // consume exp_data[i] for _ in b..big_digit::BITS { unit(r.is_odd()); r >>= 1; } // consume all other digits before the last for &r in exp_iter { let mut r = r; for _ in 0..big_digit::BITS { unit(r.is_odd()); r >>= 1; } } r = last; } debug_assert_ne!(r, 0); while !r.is_zero() { unit(r.is_odd()); r >>= 1; } } acc } #[test] fn test_plain_modpow() { let two = &BigUint::from(2u32); let modulus = BigUint::from(0x1100u32); let exp = vec![0, 0b1]; assert_eq!( two.pow(0b1_00000000_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); let exp = vec![0, 0b10]; assert_eq!( two.pow(0b10_00000000_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); let exp = vec![0, 0b110010]; assert_eq!( two.pow(0b110010_00000000_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); let exp = vec![0b1, 0b1]; assert_eq!( two.pow(0b1_00000001_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); let exp = vec![0b1100, 0, 0b1]; assert_eq!( two.pow(0b1_00000000_00001100_u32) % &modulus, plain_modpow(&two, &exp, &modulus) ); } #[test] fn test_pow_biguint() { let base = BigUint::from(5u8); let exponent = BigUint::from(3u8); assert_eq!(BigUint::from(125u8), base.pow(exponent)); }