From 88bcdfd5768cf99a7199f8d6010d5f479b6ba476 Mon Sep 17 00:00:00 2001 From: Grzegorz Swirski Date: Fri, 8 Dec 2023 19:34:40 +0100 Subject: [PATCH] feat: use AVX2 instructions whenever available --- src/hash/rescue/arch/mod.rs | 101 +++++++++ src/hash/rescue/arch/x86_64_avx2.rs | 325 ++++++++++++++++++++++++++++ src/hash/rescue/mod.rs | 59 +---- src/hash/rescue/rpo/mod.rs | 13 +- src/hash/rescue/rpx/mod.rs | 27 +-- 5 files changed, 442 insertions(+), 83 deletions(-) create mode 100644 src/hash/rescue/arch/mod.rs create mode 100644 src/hash/rescue/arch/x86_64_avx2.rs diff --git a/src/hash/rescue/arch/mod.rs b/src/hash/rescue/arch/mod.rs new file mode 100644 index 00000000..6706363f --- /dev/null +++ b/src/hash/rescue/arch/mod.rs @@ -0,0 +1,101 @@ +#[cfg(all(target_feature = "sve", feature = "sve"))] +pub mod optimized { + use crate::hash::rescue::STATE_WIDTH; + use crate::Felt; + + mod ffi { + #[link(name = "rpo_sve", kind = "static")] + extern "C" { + pub fn add_constants_and_apply_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; + pub fn add_constants_and_apply_inv_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; + } + } + + #[inline(always)] + pub fn add_constants_and_apply_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], + ) -> bool { + unsafe { + ffi::add_constants_and_apply_sbox( + state.as_mut_ptr() as *mut u64, + ark.as_ptr() as *const u64, + ) + } + } + + #[inline(always)] + pub fn add_constants_and_apply_inv_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], + ) -> bool { + unsafe { + ffi::add_constants_and_apply_inv_sbox( + state.as_mut_ptr() as *mut u64, + ark.as_ptr() as *const u64, + ) + } + } +} + +#[cfg(target_feature = "avx2")] +mod x86_64_avx2; + +#[cfg(target_feature = "avx2")] +pub mod optimized { + use super::x86_64_avx2::{apply_inv_sbox, apply_sbox}; + use crate::hash::rescue::{add_constants, STATE_WIDTH}; + use 
crate::Felt; + + #[inline(always)] + pub fn add_constants_and_apply_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], + ) -> bool { + add_constants(state, ark); + unsafe { + apply_sbox(std::mem::transmute(state)); + } + true + } + + #[inline(always)] + pub fn add_constants_and_apply_inv_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], + ) -> bool { + add_constants(state, ark); + unsafe { + apply_inv_sbox(std::mem::transmute(state)); + } + true + } +} + +#[cfg(not(any(target_feature = "avx2", all(target_feature = "sve", feature = "sve"))))] +pub mod optimized { + use crate::hash::rescue::STATE_WIDTH; + use crate::Felt; + + #[inline(always)] + pub fn add_constants_and_apply_sbox( + _state: &mut [Felt; STATE_WIDTH], + _ark: &[Felt; STATE_WIDTH], + ) -> bool { + false + } + + #[inline(always)] + pub fn add_constants_and_apply_inv_sbox( + _state: &mut [Felt; STATE_WIDTH], + _ark: &[Felt; STATE_WIDTH], + ) -> bool { + false + } +} diff --git a/src/hash/rescue/arch/x86_64_avx2.rs b/src/hash/rescue/arch/x86_64_avx2.rs new file mode 100644 index 00000000..34da0de1 --- /dev/null +++ b/src/hash/rescue/arch/x86_64_avx2.rs @@ -0,0 +1,325 @@ +use core::arch::x86_64::*; + +// The following AVX2 implementation has been copied from plonky2: +// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs + +// Preliminary notes: +// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily +// emulated. The method recognizes that for a + b overflowed iff (a + b) < a: +// i. res_lo = a_lo + b_lo +// ii. carry_mask = res_lo < a_lo +// iii. res_hi = a_hi + b_hi - carry_mask +// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions +// return -1 (all bits 1) for true and 0 for false. +// +// 2. AVX does not have unsigned 64-bit comparisons. 
Those can be emulated with signed comparisons +// by recognizing that a <u b iff a ^ (1 << 63) <s b ^ (1 << 63), i.e. flipping the top bit of +// both operands turns an unsigned comparison into a signed one. +// NOTE(review): the end of this sentence, the `branch_hint` helper, and the `map3!` macro +// header below were reconstructed from the plonky2 source cited above after this span was +// garbled in extraction — confirm against upstream before applying this patch. + +#[inline(always)] +fn branch_hint() { + unsafe { + core::arch::asm!("", options(nomem, nostack, preserves_flags)); + } +} + +macro_rules! map3 { + ($f:ident::<$l:literal>, $v:ident) => { + ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2)) + }; + ($f:ident::<$l:literal>, $v1:ident, $v2:ident) => { + ($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2)) + }; + ($f:ident, $v:ident) => { + ($f($v.0), $f($v.1), $f($v.2)) + }; + ($f:ident, $v0:ident, $v1:ident) => { + ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2)) + }; + ($f:ident, rep $v0:ident, $v1:ident) => { + ($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2)) + }; + + ($f:ident, $v0:ident, rep $v1:ident) => { + ($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1)) + }; +} + +#[inline(always)] +unsafe fn square3( + x: (__m256i, __m256i, __m256i), +) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { + let x_hi = { + // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than + // bitshift. This instruction only has a floating-point flavor, so we cast to/from float. + // This is safe and free. + let x_ps = map3!(_mm256_castsi256_ps, x); + let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps); + map3!(_mm256_castps_si256, x_hi_ps) + }; + + // All pairwise multiplications. + let mul_ll = map3!(_mm256_mul_epu32, x, x); + let mul_lh = map3!(_mm256_mul_epu32, x, x_hi); + let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi); + + // Bignum addition, but mul_lh is shifted by 33 bits (not 32). + let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll); + let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi); + let t0_hi = map3!(_mm256_srli_epi64::<31>, t0); + let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi); + + // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high + // position). 
+ let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh); + let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo); + + (res_lo, res_hi) +} + +#[inline(always)] +unsafe fn mul3( + x: (__m256i, __m256i, __m256i), + y: (__m256i, __m256i, __m256i), +) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { + let epsilon = _mm256_set1_epi64x(0xffffffff); + let x_hi = { + // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than + // bitshift. This instruction only has a floating-point flavor, so we cast to/from float. + // This is safe and free. + let x_ps = map3!(_mm256_castsi256_ps, x); + let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps); + map3!(_mm256_castps_si256, x_hi_ps) + }; + let y_hi = { + let y_ps = map3!(_mm256_castsi256_ps, y); + let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps); + map3!(_mm256_castps_si256, y_hi_ps) + }; + + // All four pairwise multiplications + let mul_ll = map3!(_mm256_mul_epu32, x, y); + let mul_lh = map3!(_mm256_mul_epu32, x, y_hi); + let mul_hl = map3!(_mm256_mul_epu32, x_hi, y); + let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi); + + // Bignum addition + // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow. + let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll); + let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi); + // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow. + // Also, extract high 32 bits of t0 and add to mul_hh. + let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon); + let t0_hi = map3!(_mm256_srli_epi64::<32>, t0); + let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo); + let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi); + // Lastly, extract the high 32 bits of t1 and add to t2. + let t1_hi = map3!(_mm256_srli_epi64::<32>, t1); + let res_hi = map3!(_mm256_add_epi64, t2, t1_hi); + + // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high + // position). 
+ let t1_lo = { + let t1_ps = map3!(_mm256_castsi256_ps, t1); + let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps); + map3!(_mm256_castps_si256, t1_lo_ps) + }; + let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo); + + (res_lo, res_hi) +} + +/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`. +#[inline(always)] +unsafe fn add_small( + x_s: (__m256i, __m256i, __m256i), + y: (__m256i, __m256i, __m256i), +) -> (__m256i, __m256i, __m256i) { + let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y); + let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s); + let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0. + let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt); + res_s +} + +#[inline(always)] +unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i { + // The subtraction is very unlikely to overflow so we're best off branching. + // The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd` + // branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to + // floating-point (this is free). + let mask_pd = _mm256_castsi256_pd(mask); + // `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow + // did not occur for any of the vector elements. + if _mm256_testz_pd(mask_pd, mask_pd) == 1 { + res_wrapped_s + } else { + branch_hint(); + // Highly unlikely: underflow did occur. Find adjustment per element and apply it. + let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow. + _mm256_sub_epi64(res_wrapped_s, adj_amount) + } +} + +/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`. 
+#[inline(always)] +unsafe fn sub_tiny( + x_s: (__m256i, __m256i, __m256i), + y: (__m256i, __m256i, __m256i), +) -> (__m256i, __m256i, __m256i) { + let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y); + let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s); + let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask); + res_s +} + +#[inline(always)] +unsafe fn reduce3( + (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)), +) -> (__m256i, __m256i, __m256i) { + let sign_bit = _mm256_set1_epi64x(i64::MIN); + let epsilon = _mm256_set1_epi64x(0xffffffff); + let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit); + let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0); + let lo1_s = sub_tiny(lo0_s, hi_hi0); + let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon); + let lo2_s = add_small(lo1_s, t1); + let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit); + lo2 +} + +#[inline(always)] +unsafe fn mul_reduce( + a: (__m256i, __m256i, __m256i), + b: (__m256i, __m256i, __m256i), +) -> (__m256i, __m256i, __m256i) { + reduce3(mul3(a, b)) +} + +#[inline(always)] +unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) { + reduce3(square3(state)) +} + +#[inline(always)] +unsafe fn exp_acc( + high: (__m256i, __m256i, __m256i), + low: (__m256i, __m256i, __m256i), + exp: usize, +) -> (__m256i, __m256i, __m256i) { + let mut result = high; + for _ in 0..exp { + result = square_reduce(result); + } + mul_reduce(result, low) +} + +#[inline(always)] +unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) { + let state2 = square_reduce(state); + let state4_unreduced = square3(state2); + let state3_unreduced = mul3(state2, state); + let state4 = reduce3(state4_unreduced); + let state3 = reduce3(state3_unreduced); + let state7_unreduced = mul3(state3, state4); + let state7 = reduce3(state7_unreduced); + state7 +} + +#[inline(always)] +unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, 
__m256i, __m256i) { + // compute base^10540996611094048183 using 72 multiplications per array element + // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 + + // compute base^10 + let t1 = square_reduce(state); + + // compute base^100 + let t2 = square_reduce(t1); + + // compute base^100100 + let t3 = exp_acc(t2, t2, 3); + + // compute base^100100100100 + let t4 = exp_acc(t3, t3, 6); + + // compute base^100100100100100100100100 + let t5 = exp_acc(t4, t4, 12); + + // compute base^100100100100100100100100100100 + let t6 = exp_acc(t5, t3, 6); + + // compute base^1001001001001001001001001001000100100100100100100100100100100 + let t7 = exp_acc(t6, t6, 31); + + // compute base^1001001001001001001001001001000110110110110110110110110110110111 + let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6))); + let b = mul_reduce(t1, mul_reduce(t2, state)); + mul_reduce(a, b) +} + +#[inline(always)] +unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) { + ( + _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()), + _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()), + _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()), + ) +} + +#[inline(always)] +unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) { + _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0); + _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1); + _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2); +} + +#[inline(always)] +pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) { + let mut state = avx2_load(&buffer); + state = do_apply_sbox(state); + avx2_store(buffer, state); +} + +#[inline(always)] +pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) { + let mut state = avx2_load(&buffer); + state = do_apply_inv_sbox(state); + avx2_store(buffer, state); +} diff --git a/src/hash/rescue/mod.rs 
b/src/hash/rescue/mod.rs index 2fa942ac..d304064c 100644 --- a/src/hash/rescue/mod.rs +++ b/src/hash/rescue/mod.rs @@ -3,6 +3,9 @@ use super::{ }; use core::ops::Range; +mod arch; +pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox}; + mod mds; use mds::{apply_mds, MDS}; @@ -129,62 +132,6 @@ fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { } } -// OPTIMIZATIONS -// ================================================================================================ - -#[cfg(all(target_feature = "sve", feature = "sve"))] -#[link(name = "rpo_sve", kind = "static")] -extern "C" { - fn add_constants_and_apply_sbox( - state: *mut std::ffi::c_ulong, - constants: *const std::ffi::c_ulong, - ) -> bool; - fn add_constants_and_apply_inv_sbox( - state: *mut std::ffi::c_ulong, - constants: *const std::ffi::c_ulong, - ) -> bool; -} - -#[inline(always)] -#[cfg(all(target_feature = "sve", feature = "sve"))] -fn optimized_add_constants_and_apply_sbox( - state: &mut [Felt; STATE_WIDTH], - ark: &[Felt; STATE_WIDTH], -) -> bool { - unsafe { - add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64) - } -} - -#[inline(always)] -#[cfg(not(all(target_feature = "sve", feature = "sve")))] -fn optimized_add_constants_and_apply_sbox( - _state: &mut [Felt; STATE_WIDTH], - _ark: &[Felt; STATE_WIDTH], -) -> bool { - false -} - -#[inline(always)] -#[cfg(all(target_feature = "sve", feature = "sve"))] -fn optimized_add_constants_and_apply_inv_sbox( - state: &mut [Felt; STATE_WIDTH], - ark: &[Felt; STATE_WIDTH], -) -> bool { - unsafe { - add_constants_and_apply_inv_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64) - } -} - -#[inline(always)] -#[cfg(not(all(target_feature = "sve", feature = "sve")))] -fn optimized_add_constants_and_apply_inv_sbox( - _state: &mut [Felt; STATE_WIDTH], - _ark: &[Felt; STATE_WIDTH], -) -> bool { - false -} - #[inline(always)] fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: 
&[Felt; STATE_WIDTH]) { state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k); diff --git a/src/hash/rescue/rpo/mod.rs b/src/hash/rescue/rpo/mod.rs index 120ea0e2..24a9f681 100644 --- a/src/hash/rescue/rpo/mod.rs +++ b/src/hash/rescue/rpo/mod.rs @@ -1,9 +1,8 @@ use super::{ - add_constants, apply_inv_sbox, apply_mds, apply_sbox, - optimized_add_constants_and_apply_inv_sbox, optimized_add_constants_and_apply_sbox, Digest, - ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, - CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, - NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO, + add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox, + apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1, + ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE, + INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO, }; use core::{convert::TryInto, ops::Range}; @@ -309,14 +308,14 @@ impl Rpo256 { pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) { // apply first half of RPO round apply_mds(state); - if !optimized_add_constants_and_apply_sbox(state, &ARK1[round]) { + if !add_constants_and_apply_sbox(state, &ARK1[round]) { add_constants(state, &ARK1[round]); apply_sbox(state); } // apply second half of RPO round apply_mds(state); - if !optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) { + if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) { add_constants(state, &ARK2[round]); apply_inv_sbox(state); } diff --git a/src/hash/rescue/rpx/mod.rs b/src/hash/rescue/rpx/mod.rs index 89429dff..5458c072 100644 --- a/src/hash/rescue/rpx/mod.rs +++ b/src/hash/rescue/rpx/mod.rs @@ -1,28 +1,15 @@ use super::{ - add_constants, apply_inv_sbox, apply_mds, apply_sbox, - optimized_add_constants_and_apply_inv_sbox, 
optimized_add_constants_and_apply_sbox, - CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1, ARK2, - BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE, - INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO, + add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox, + apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, + StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, + DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, + STATE_WIDTH, ZERO, }; use core::{convert::TryInto, ops::Range}; mod digest; pub use digest::RpxDigest; -#[cfg(all(target_feature = "sve", feature = "sve"))] -#[link(name = "rpo_sve", kind = "static")] -extern "C" { - fn add_constants_and_apply_sbox( - state: *mut std::ffi::c_ulong, - constants: *const std::ffi::c_ulong, - ) -> bool; - fn add_constants_and_apply_inv_sbox( - state: *mut std::ffi::c_ulong, - constants: *const std::ffi::c_ulong, - ) -> bool; -} - pub type CubicExtElement = CubeExtension; // HASHER IMPLEMENTATION @@ -327,13 +314,13 @@ impl Rpx256 { #[inline(always)] pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) { apply_mds(state); - if !optimized_add_constants_and_apply_sbox(state, &ARK1[round]) { + if !add_constants_and_apply_sbox(state, &ARK1[round]) { add_constants(state, &ARK1[round]); apply_sbox(state); } apply_mds(state); - if !optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) { + if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) { add_constants(state, &ARK2[round]); apply_inv_sbox(state); }