diff --git a/etc/function-definitions.json b/etc/function-definitions.json index 64a775ba9..bca58402f 100644 --- a/etc/function-definitions.json +++ b/etc/function-definitions.json @@ -342,12 +342,14 @@ }, "fma": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/fma.rs" ], "type": "f64" }, "fmaf": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/fma_wide.rs" ], "type": "f32" @@ -806,6 +808,7 @@ }, "rintf16": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/rint.rs" ], "type": "f16" @@ -928,6 +931,7 @@ }, "sqrt": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/arch/i686.rs", "src/math/arch/wasm32.rs", "src/math/generic/sqrt.rs", @@ -937,6 +941,7 @@ }, "sqrtf": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/arch/i686.rs", "src/math/arch/wasm32.rs", "src/math/generic/sqrt.rs", @@ -953,6 +958,7 @@ }, "sqrtf16": { "sources": [ + "src/math/arch/aarch64.rs", "src/math/generic/sqrt.rs", "src/math/sqrtf16.rs" ], diff --git a/src/math/arch/aarch64.rs b/src/math/arch/aarch64.rs index 374ec11bf..020bb731c 100644 --- a/src/math/arch/aarch64.rs +++ b/src/math/arch/aarch64.rs @@ -1,33 +1,115 @@ -use core::arch::aarch64::{ - float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32, - vrndn_f64, -}; +//! Architecture-specific support for aarch64 with neon. -pub fn rint(x: f64) -> f64 { - // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. - let x_vec: float64x1_t = unsafe { vdup_n_f64(x) }; +use core::arch::asm; - // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. - let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) }; +pub fn fma(mut x: f64, y: f64, z: f64) -> f64 { + // SAFETY: `fmadd` is available with neon and has no side effects. + unsafe { + asm!( + "fmadd {x:d}, {x:d}, {y:d}, {z:d}", + x = inout(vreg) x, + y = in(vreg) y, + z = in(vreg) z, + options(nomem, nostack, pure) + ); + } + x +} - // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. - let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) }; +pub fn fmaf(mut x: f32, y: f32, z: f32) -> f32 { + // SAFETY: `fmadd` is available with neon and has no side effects. + unsafe { + asm!( + "fmadd {x:s}, {x:s}, {y:s}, {z:s}", + x = inout(vreg) x, + y = in(vreg) y, + z = in(vreg) z, + options(nomem, nostack, pure) + ); + } + x +} - result +pub fn rint(mut x: f64) -> f64 { + // SAFETY: `frintn` is available with neon and has no side effects. + // + // `frintn` is always round-to-nearest which does not match the C specification, but Rust does + // not support rounding modes. + unsafe { + asm!( + "frintn {x:d}, {x:d}", + x = inout(vreg) x, + options(nomem, nostack, pure) + ); + } + x } -pub fn rintf(x: f32) -> f32 { - // There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we - // have to use the vector form and drop the other lanes afterwards. +pub fn rintf(mut x: f32) -> f32 { + // SAFETY: `frintn` is available with neon and has no side effects. + // + // `frintn` is always round-to-nearest which does not match the C specification, but Rust does + // not support rounding modes. + unsafe { + asm!( + "frintn {x:s}, {x:s}", + x = inout(vreg) x, + options(nomem, nostack, pure) + ); + } + x +} - // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. - let x_vec: float32x2_t = unsafe { vdup_n_f32(x) }; +#[cfg(all(f16_enabled, target_feature = "fp16"))] +pub fn rintf16(mut x: f16) -> f16 { + // SAFETY: `frintn` is available for `f16` with `fp16` (implies `neon`) and has no side effects. + // + // `frintn` is always round-to-nearest which does not match the C specification, but Rust does + // not support rounding modes. + unsafe { + asm!( + "frintn {x:h}, {x:h}", + x = inout(vreg) x, + options(nomem, nostack, pure) + ); + } + x +} - // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. - let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) }; +pub fn sqrt(mut x: f64) -> f64 { + // SAFETY: `fsqrt` is available with neon and has no side effects. + unsafe { + asm!( + "fsqrt {x:d}, {x:d}", + x = inout(vreg) x, + options(nomem, nostack, pure) + ); + } + x +} - // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module. - let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) }; +pub fn sqrtf(mut x: f32) -> f32 { + // SAFETY: `fsqrt` is available with neon and has no side effects. + unsafe { + asm!( + "fsqrt {x:s}, {x:s}", + x = inout(vreg) x, + options(nomem, nostack, pure) + ); + } + x +} - result +#[cfg(all(f16_enabled, target_feature = "fp16"))] +pub fn sqrtf16(mut x: f16) -> f16 { + // SAFETY: `fsqrt` is available for `f16` with `fp16` (implies `neon`) and has no + // side effects. + unsafe { + asm!( + "fsqrt {x:h}, {x:h}", + x = inout(vreg) x, + options(nomem, nostack, pure) + ); + } + x } diff --git a/src/math/arch/mod.rs b/src/math/arch/mod.rs index 091d7650a..d9f2aad66 100644 --- a/src/math/arch/mod.rs +++ b/src/math/arch/mod.rs @@ -18,12 +18,25 @@ cfg_if! { mod i686; pub use i686::{sqrt, sqrtf}; } else if #[cfg(all( - target_arch = "aarch64", // TODO: also arm64ec? - target_feature = "neon", - target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484 + any(target_arch = "aarch64", target_arch = "arm64ec"), + target_feature = "neon" ))] { mod aarch64; - pub use aarch64::{rint, rintf}; + + pub use aarch64::{ + fma, + fmaf, + rint, + rintf, + sqrt, + sqrtf, + }; + + #[cfg(all(f16_enabled, target_feature = "fp16"))] + pub use aarch64::{ + rintf16, + sqrtf16, + }; } } diff --git a/src/math/fma.rs b/src/math/fma.rs index 049f573cc..789b0836a 100644 --- a/src/math/fma.rs +++ b/src/math/fma.rs @@ -9,6 +9,12 @@ use super::{CastFrom, CastInto, Float, Int, MinInt}; /// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] pub fn fma(x: f64, y: f64, z: f64) -> f64 { + select_implementation! { + name: fma, + use_arch: all(target_arch = "aarch64", target_feature = "neon"), + args: x, y, z, + } + fma_round(x, y, z, Round::Nearest).val } diff --git a/src/math/fma_wide.rs b/src/math/fma_wide.rs index d0cf33baf..8e908a14f 100644 --- a/src/math/fma_wide.rs +++ b/src/math/fma_wide.rs @@ -17,6 +17,12 @@ pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 { /// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision). #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] pub fn fmaf(x: f32, y: f32, z: f32) -> f32 { + select_implementation! { + name: fmaf, + use_arch: all(target_arch = "aarch64", target_feature = "neon"), + args: x, y, z, + } + fma_wide_round(x, y, z, Round::Nearest).val } diff --git a/src/math/rint.rs b/src/math/rint.rs index 8a5cbeab4..e1c32c943 100644 --- a/src/math/rint.rs +++ b/src/math/rint.rs @@ -4,6 +4,12 @@ use super::support::Round; #[cfg(f16_enabled)] #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] pub fn rintf16(x: f16) -> f16 { + select_implementation! { + name: rintf16, + use_arch: all(target_arch = "aarch64", target_feature = "fp16"), + args: x, + } + super::generic::rint_round(x, Round::Nearest).val } @@ -13,8 +19,8 @@ pub fn rintf(x: f32) -> f32 { select_implementation! { name: rintf, use_arch: any( + all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "wasm32", intrinsics_enabled), - all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"), ), args: x, } @@ -28,8 +34,8 @@ pub fn rint(x: f64) -> f64 { select_implementation! { name: rint, use_arch: any( + all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "wasm32", intrinsics_enabled), - all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"), ), args: x, } diff --git a/src/math/sqrt.rs b/src/math/sqrt.rs index 0e1d0cd2c..2bfc42bcf 100644 --- a/src/math/sqrt.rs +++ b/src/math/sqrt.rs @@ -4,6 +4,7 @@ pub fn sqrt(x: f64) -> f64 { select_implementation! { name: sqrt, use_arch: any( + all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "wasm32", intrinsics_enabled), target_feature = "sse2" ), diff --git a/src/math/sqrtf.rs b/src/math/sqrtf.rs index 2e69a4b66..c28a705e3 100644 --- a/src/math/sqrtf.rs +++ b/src/math/sqrtf.rs @@ -4,6 +4,7 @@ pub fn sqrtf(x: f32) -> f32 { select_implementation! { name: sqrtf, use_arch: any( + all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "wasm32", intrinsics_enabled), target_feature = "sse2" ), diff --git a/src/math/sqrtf16.rs b/src/math/sqrtf16.rs index 549bf902c..7bedb7f8b 100644 --- a/src/math/sqrtf16.rs +++ b/src/math/sqrtf16.rs @@ -1,5 +1,11 @@ /// The square root of `x` (f16). #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)] pub fn sqrtf16(x: f16) -> f16 { + select_implementation! { + name: sqrtf16, + use_arch: all(target_arch = "aarch64", target_feature = "fp16"), + args: x, + } + return super::generic::sqrt(x); }