Skip to content

Commit

Permalink
Make AVX512IFMA opt-in backend
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkforest committed Aug 29, 2024
1 parent 0964f80 commit 0c260c7
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 36 deletions.
2 changes: 1 addition & 1 deletion curve25519-dalek/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ curve25519-dalek-derive = { version = "0.1", path = "../curve25519-dalek-derive"
level = "warn"
check-cfg = [
'cfg(allow_unused_unsafe)',
'cfg(curve25519_dalek_backend, values("fiat", "serial", "simd"))',
'cfg(curve25519_dalek_backend, values("fiat", "serial", "simd", "unstable_avx512"))',
'cfg(curve25519_dalek_diagnostics, values("build"))',
'cfg(curve25519_dalek_bits, values("32", "64"))',
'cfg(nightly)',
Expand Down
60 changes: 41 additions & 19 deletions curve25519-dalek/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ fn main() {

println!("cargo:rustc-cfg=curve25519_dalek_bits=\"{curve25519_dalek_bits}\"");

if rustc_version::version_meta()
let nightly = if rustc_version::version_meta()
.expect("failed to detect rustc version")
.channel
== rustc_version::Channel::Nightly
{
println!("cargo:rustc-cfg=nightly");
}
true
} else {
false
};

let rustc_version = rustc_version::version().expect("failed to detect rustc version");
if rustc_version.major == 1 && rustc_version.minor <= 64 {
Expand All @@ -51,25 +54,44 @@ fn main() {
}

// Backend overrides / defaults
let curve25519_dalek_backend =
match std::env::var("CARGO_CFG_CURVE25519_DALEK_BACKEND").as_deref() {
Ok("fiat") => "fiat",
Ok("serial") => "serial",
Ok("simd") => {
// simd can only be enabled on x86_64 & 64bit target_pointer_width
match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
// If override is not possible this must result to compile error
// See: issues/532
false => panic!("Could not override curve25519_dalek_backend to simd"),
let curve25519_dalek_backend = match std::env::var("CARGO_CFG_CURVE25519_DALEK_BACKEND")
.as_deref()
{
Ok("fiat") => "fiat",
Ok("serial") => "serial",
Ok("simd") => {
// simd can only be enabled on x86_64 & 64bit target_pointer_width
match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
// If override is not possible this must result to compile error
// See: issues/532
false => panic!("Could not override curve25519_dalek_backend to simd"),
}
}
Ok("unstable_avx512") if nightly => {
// simd can only be enabled on x86_64 & 64bit target_pointer_width
match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => {
// In addition enable Avx2 fallback through simd stable backend
// NOTE: Compiler permits duplicate / multi value on the same key
println!("cargo:rustc-cfg=curve25519_dalek_backend=\"simd\"");

"unstable_avx512"
}
// If override is not possible this must result to compile error
// See: issues/532
false => panic!("Could not override curve25519_dalek_backend to unstable_avx512"),
}
// default between serial / simd (if potentially capable)
_ => match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
false => "serial",
},
};
}
Ok("unstable_avx512") if !nightly => {
panic!("Coult not override curve25519_dalek_backend to unstable_avx512 as this is nigthly only.");
}
// default between serial / simd (if potentially capable)
_ => match is_capable_simd(&target_arch, curve25519_dalek_bits) {
true => "simd",
false => "serial",
},
};
println!("cargo:rustc-cfg=curve25519_dalek_backend=\"{curve25519_dalek_backend}\"");
}

Expand Down
20 changes: 10 additions & 10 deletions curve25519-dalek/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ pub mod vector;
enum BackendKind {
#[cfg(curve25519_dalek_backend = "simd")]
Avx2,
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
Avx512,
Serial,
}

#[inline]
fn get_selected_backend() -> BackendKind {
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
{
cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl");
let token_avx512: cpuid_avx512::InitToken = cpuid_avx512::init();
Expand Down Expand Up @@ -88,7 +88,7 @@ where
#[cfg(curve25519_dalek_backend = "simd")]
BackendKind::Avx2 =>
vector::scalar_mul::pippenger::spec_avx2::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 =>
vector::scalar_mul::pippenger::spec_avx512ifma_avx512vl::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
BackendKind::Serial =>
Expand All @@ -100,7 +100,7 @@ where
pub(crate) enum VartimePrecomputedStraus {
#[cfg(curve25519_dalek_backend = "simd")]
Avx2(vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
Avx512ifma(
vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus,
),
Expand All @@ -120,7 +120,7 @@ impl VartimePrecomputedStraus {
#[cfg(curve25519_dalek_backend = "simd")]
BackendKind::Avx2 =>
VartimePrecomputedStraus::Avx2(vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus::new(static_points)),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 =>
VartimePrecomputedStraus::Avx512ifma(vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus::new(static_points)),
BackendKind::Serial =>
Expand Down Expand Up @@ -150,7 +150,7 @@ impl VartimePrecomputedStraus {
dynamic_scalars,
dynamic_points,
),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
VartimePrecomputedStraus::Avx512ifma(inner) => inner.optional_mixed_multiscalar_mul(
static_scalars,
dynamic_scalars,
Expand Down Expand Up @@ -181,7 +181,7 @@ where
BackendKind::Avx2 => {
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::<I, J>(scalars, points)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::<I, J>(
scalars, points,
Expand Down Expand Up @@ -210,7 +210,7 @@ where
scalars, points,
)
}
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::optional_multiscalar_mul::<
I,
Expand All @@ -228,7 +228,7 @@ pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint
match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
BackendKind::Avx2 => vector::scalar_mul::variable_base::spec_avx2::mul(point, scalar),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 => {
vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar)
}
Expand All @@ -242,7 +242,7 @@ pub fn vartime_double_base_mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> Edwa
match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
BackendKind::Avx2 => vector::scalar_mul::vartime_double_base::spec_avx2::mul(a, A, b),
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 => {
vector::scalar_mul::vartime_double_base::spec_avx512ifma_avx512vl::mul(a, A, b)
}
Expand Down
2 changes: 1 addition & 1 deletion curve25519-dalek/src/backend/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub mod packed_simd;

pub mod avx2;

#[cfg(nightly)]
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
pub mod ifma;

pub mod scalar_mul;
5 changes: 4 additions & 1 deletion curve25519-dalek/src/backend/vector/scalar_mul/pippenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

#[curve25519_dalek_derive::unsafe_target_feature_specialize(
"avx2",
conditional("avx512ifma,avx512vl", nightly)
conditional(
"avx512ifma,avx512vl",
all(curve25519_dalek_backend = "unstable_avx512", nightly)
)
)]
pub mod spec {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

#[curve25519_dalek_derive::unsafe_target_feature_specialize(
"avx2",
conditional("avx512ifma,avx512vl", nightly)
conditional(
"avx512ifma,avx512vl",
all(curve25519_dalek_backend = "unstable_avx512", nightly)
)
)]
pub mod spec {

Expand Down
5 changes: 4 additions & 1 deletion curve25519-dalek/src/backend/vector/scalar_mul/straus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

#[curve25519_dalek_derive::unsafe_target_feature_specialize(
"avx2",
conditional("avx512ifma,avx512vl", nightly)
conditional(
"avx512ifma,avx512vl",
all(curve25519_dalek_backend = "unstable_avx512", nightly)
)
)]
pub mod spec {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

#[curve25519_dalek_derive::unsafe_target_feature_specialize(
"avx2",
conditional("avx512ifma,avx512vl", nightly)
conditional(
"avx512ifma,avx512vl",
all(curve25519_dalek_backend = "unstable_avx512", nightly)
)
)]
pub mod spec {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

#[curve25519_dalek_derive::unsafe_target_feature_specialize(
"avx2",
conditional("avx512ifma,avx512vl", nightly)
conditional(
"avx512ifma,avx512vl",
all(curve25519_dalek_backend = "unstable_avx512", nightly)
)
)]
pub mod spec {

Expand Down

0 comments on commit 0c260c7

Please sign in to comment.