From 344a6b9592a5c56c97459705cf95ed51edd0063b Mon Sep 17 00:00:00 2001 From: fionser Date: Sun, 4 Feb 2024 23:02:34 +0800 Subject: [PATCH] [opt] Optimizing Shoup's MulMod in Apple M1/M2 chip. --- src/prime64/less_than_62bit.rs | 98 ++++++++++++++++++++++++++++++++++ src/prime64/less_than_63bit.rs | 73 +++++++++++++++++++++++++ 2 files changed, 171 insertions(+) diff --git a/src/prime64/less_than_62bit.rs b/src/prime64/less_than_62bit.rs index 40489c6..14ea1f1 100644 --- a/src/prime64/less_than_62bit.rs +++ b/src/prime64/less_than_62bit.rs @@ -113,6 +113,53 @@ pub(crate) fn fwd_last_butterfly_avx2( ) } +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn fwd_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = p; + let z0 = z0.min(z0.wrapping_sub(two_p)); + + let shoup_q_u128 = (z1 as u128 * w_shoup as u128) >> 64; + let t = ((z1 as u128 * w as u128) + (shoup_q_u128 * neg_p as u128)) as u64; + + (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(two_p)) +} + +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn fwd_last_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = p; + let z0 = z0.min(z0.wrapping_sub(two_p)); + let z0 = z0.min(z0.wrapping_sub(p)); + + let shoup_q_u128 = (z1 as u128 * w_shoup as u128) >> 64; + let t = ((z1 as u128 * w as u128) + (shoup_q_u128 * neg_p as u128)) as u64; + + let t = t.min(t.wrapping_sub(p)); + let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p)); + ( + res.0.min(res.0.wrapping_sub(p)), + res.1.min(res.1.wrapping_sub(p)), + ) +} + +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn fwd_butterfly_scalar( z0: u64, @@ -130,6 +177,7 @@ pub(crate) fn fwd_butterfly_scalar( (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(two_p)) } +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn fwd_last_butterfly_scalar( z0: u64, @@ -267,6 +315,55 @@ pub(crate) fn inv_last_butterfly_avx2( (y0, y1) } +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn inv_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = p; + + let y0 = z0.wrapping_add(z1); + let y0 = y0.min(y0.wrapping_sub(two_p)); + let t = z0.wrapping_sub(z1).wrapping_add(two_p); + + let shoup_q_u128 = (t as u128 * w_shoup as u128) >> 64; + let y1 = ((t as u128 * w as u128) + shoup_q_u128 * neg_p as u128) as u64; + + (y0, y1) +} + +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn inv_last_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = p; + + let y0 = z0.wrapping_add(z1); + let y0 = y0.min(y0.wrapping_sub(two_p)); + let y0 = y0.min(y0.wrapping_sub(p)); + let t = z0.wrapping_sub(z1).wrapping_add(two_p); + + let shoup_q_u128 = (t as u128 * w_shoup as u128) >> 64; + let y1 = ((t as u128 * w as u128) + shoup_q_u128 * neg_p as u128) as u64; + + let y1 = y1.min(y1.wrapping_sub(p)); + (y0, y1) +} + +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn inv_butterfly_scalar( z0: u64, @@ -287,6 +384,7 @@ pub(crate) fn inv_butterfly_scalar( (y0, y1) } +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn inv_last_butterfly_scalar( z0: u64, diff --git a/src/prime64/less_than_63bit.rs b/src/prime64/less_than_63bit.rs index 7090015..0499e09 100644 --- a/src/prime64/less_than_63bit.rs +++ b/src/prime64/less_than_63bit.rs @@ -113,6 +113,7 @@ pub(crate) fn fwd_last_butterfly_avx2( ) } +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn fwd_butterfly_scalar( z0: u64, @@ -131,6 +132,7 @@ pub(crate) fn fwd_butterfly_scalar( (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p)) } +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn fwd_last_butterfly_scalar( z0: u64, @@ -153,6 +155,52 @@ pub(crate) fn fwd_last_butterfly_scalar( ) } +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn fwd_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = two_p; + let z0 = z0.min(z0.wrapping_sub(p)); + + let shoup_q_u128 = (z1 as u128 * w_shoup as u128) >> 64; + let t = ((z1 as u128 * w as u128) + (shoup_q_u128 * neg_p as u128)) as u64; + + let t = t.min(t.wrapping_sub(p)); + (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p)) +} + +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn fwd_last_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = two_p; + let z0 = z0.min(z0.wrapping_sub(p)); + + let shoup_q_u128 = (z1 as u128 * w_shoup as u128) >> 64; + let t = ((z1 as u128 * w as u128) + (shoup_q_u128 * neg_p as u128)) as u64; + + let t = t.min(t.wrapping_sub(p)); + let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p)); + ( + res.0.min(res.0.wrapping_sub(p)), + res.1.min(res.1.wrapping_sub(p)), + ) +} + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(feature = "nightly")] #[inline(always)] @@ -210,6 +258,31 @@ pub(crate) fn inv_butterfly_avx2( (y0, y1) } +#[cfg(any(target_arch = "aarch64"))] +#[inline(always)] +pub(crate) fn inv_butterfly_scalar( + z0: u64, + z1: u64, + w: u64, + w_shoup: u64, + p: u64, + neg_p: u64, + two_p: u64, +) -> (u64, u64) { + let _ = two_p; + + let y0 = z0.wrapping_add(z1); + let y0 = y0.min(y0.wrapping_sub(p)); + let t = z0.wrapping_sub(z1).wrapping_add(p); + + let shoup_q_u128 = (t as u128 * w_shoup as u128) >> 64; + let y1 = ((t as u128 * w as u128) + shoup_q_u128 * neg_p as u128) as u64; + + let y1 = y1.min(y1.wrapping_sub(p)); + (y0, y1) +} + +#[cfg(not(target_arch = "aarch64"))] #[inline(always)] pub(crate) fn inv_butterfly_scalar( z0: u64,