diff --git a/crates/core_simd/src/comparisons.rs b/crates/core_simd/src/comparisons.rs index edef5af3687..c481a1a95a7 100644 --- a/crates/core_simd/src/comparisons.rs +++ b/crates/core_simd/src/comparisons.rs @@ -1,56 +1,475 @@ use crate::simd::intrinsics; use crate::simd::{LaneCount, Mask, Simd, SimdElement, SupportedLaneCount}; +mod eq { + use super::*; + + pub trait Sealed: SimdElement { + /// Implementation detail of [`Simd::lanes_eq`]. + fn lanes_eq( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::lanes_ne`]. + fn lanes_ne( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount; + } +} + +mod ord { + use super::*; + + pub trait Sealed: SimdElement { + /// Implementation detail of [`Simd::lanes_lt`]. + fn lanes_lt( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::lanes_gt`]. + fn lanes_gt( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::lanes_le`]. + fn lanes_le( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::lanes_ge`]. + fn lanes_ge( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount; + /// Implementation detail of [`Simd::min`]. + fn min( + lhs: Simd, + rhs: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::max`]. + fn max( + lhs: Simd, + rhs: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::horizontal_min`]. + fn horizontal_min(x: Simd) -> Self + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::horizontal_max`]. + fn horizontal_max(x: Simd) -> Self + where + LaneCount: SupportedLaneCount; + + /// Implementation detail of [`Simd::clamp`]. + #[inline] + fn clamp( + mut x: Simd, + min: Simd, + max: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount, + { + assert!( + Self::lanes_le(min, max).all(), + "each lane in `min` must be less than or equal to the corresponding lane in `max`", + ); + x = Self::lanes_lt(x, min).select(min, x); + x = Self::lanes_gt(x, max).select(max, x); + x + } + } +} + +/// SIMD vector element types that implement [`PartialEq`]. +pub trait SimdPartialEq: SimdElement + PartialEq + eq::Sealed {} + +/// SIMD vector element types that implement [`PartialOrd`] and can always be compared. +/// +/// Note that this trait is has one additional requirement beyond [`PartialOrd`]: all values can be +/// compared with all other values. +/// This is similar to [`Ord`], but without the requirement that comparisons are symmetric +/// (e.g. `a < b` and `a > b` can both be true for some values). +pub trait SimdPartialOrd: SimdElement + PartialOrd + ord::Sealed {} + +macro_rules! impl_integer { + { unsafe { $($type:ty),* } } => { + $( + impl eq::Sealed for $type { + #[inline] + fn lanes_eq( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_eq(lhs, rhs)) } + } + + #[inline] + fn lanes_ne( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_ne(lhs, rhs)) } + } + } + + impl SimdPartialEq for $type {} + + impl ord::Sealed for $type { + #[inline] + fn lanes_lt( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(lhs, rhs)) } + } + + #[inline] + fn lanes_gt( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(lhs, rhs)) } + } + + #[inline] + fn lanes_le( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_le(lhs, rhs)) } + } + + #[inline] + fn lanes_ge( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(lhs, rhs)) } + } + + #[inline] + fn min( + lhs: Simd, + rhs: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount + { + // TODO consider using an intrinsic + lhs.lanes_ge(rhs).select(rhs, lhs) + } + + #[inline] + fn max( + lhs: Simd, + rhs: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount + { + // TODO consider using an intrinsic + lhs.lanes_le(rhs).select(rhs, lhs) + } + + #[inline] + fn horizontal_min(x: Simd) -> Self + where + LaneCount: SupportedLaneCount + { + unsafe { intrinsics::simd_reduce_min(x) } + } + + #[inline] + fn horizontal_max(x: Simd) -> Self + where + LaneCount: SupportedLaneCount + { + unsafe { intrinsics::simd_reduce_max(x) } + } + + } + + impl SimdPartialOrd for $type {} + )* + } +} + +macro_rules! impl_float { + { unsafe { $($type:ty),* } } => { + $( + impl eq::Sealed for $type { + #[inline] + fn lanes_eq( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_eq(lhs, rhs)) } + } + + #[inline] + fn lanes_ne( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_ne(lhs, rhs)) } + } + } + + impl SimdPartialEq for $type {} + + impl ord::Sealed for $type { + #[inline] + fn lanes_lt( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(lhs, rhs)) } + } + + #[inline] + fn lanes_gt( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(lhs, rhs)) } + } + + #[inline] + fn lanes_le( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_le(lhs, rhs)) } + } + + #[inline] + fn lanes_ge( + lhs: Simd, + rhs: Simd, + ) -> Mask + where + LaneCount: SupportedLaneCount + { + unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(lhs, rhs)) } + } + + #[inline] + fn min( + lhs: Simd, + rhs: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount + { + // TODO consider using an intrinsic + lhs.is_nan() + .select(rhs, lhs.lanes_ge(rhs).select(rhs, lhs)) + } + + #[inline] + fn max( + lhs: Simd, + rhs: Simd, + ) -> Simd + where + LaneCount: SupportedLaneCount + { + // TODO consider using an intrinsic + lhs.is_nan() + .select(rhs, lhs.lanes_le(rhs).select(rhs, lhs)) + } + + #[inline] + fn horizontal_min(x: Simd) -> Self + where + LaneCount: SupportedLaneCount + { + unsafe { intrinsics::simd_reduce_min(x) } + } + + #[inline] + fn horizontal_max(x: Simd) -> Self + where + LaneCount: SupportedLaneCount + { + unsafe { intrinsics::simd_reduce_max(x) } + } + } + + impl SimdPartialOrd for $type {} + )* + } +} + +impl_integer! { unsafe { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize } } +impl_float! { unsafe { f32, f64 } } + impl Simd where - T: SimdElement + PartialEq, + T: SimdPartialEq, LaneCount: SupportedLaneCount, { /// Test if each lane is equal to the corresponding lane in `other`. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn lanes_eq(self, other: Self) -> Mask { - unsafe { Mask::from_int_unchecked(intrinsics::simd_eq(self, other)) } + T::lanes_eq(self, other) } /// Test if each lane is not equal to the corresponding lane in `other`. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn lanes_ne(self, other: Self) -> Mask { - unsafe { Mask::from_int_unchecked(intrinsics::simd_ne(self, other)) } + T::lanes_ne(self, other) } } impl Simd where - T: SimdElement + PartialOrd, + T: SimdPartialOrd, LaneCount: SupportedLaneCount, { /// Test if each lane is less than the corresponding lane in `other`. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn lanes_lt(self, other: Self) -> Mask { - unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) } + T::lanes_lt(self, other) } /// Test if each lane is greater than the corresponding lane in `other`. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn lanes_gt(self, other: Self) -> Mask { - unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) } + T::lanes_gt(self, other) } /// Test if each lane is less than or equal to the corresponding lane in `other`. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn lanes_le(self, other: Self) -> Mask { - unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) } + T::lanes_le(self, other) } /// Test if each lane is greater than or equal to the corresponding lane in `other`. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn lanes_ge(self, other: Self) -> Mask { - unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) } + T::lanes_ge(self, other) + } + + /// Returns the minimum of each lane. + /// + /// # Note + /// For `f32` and `f64`, if one of the values is `NAN`, then the other value is returned. + /// If the compared values are `0.0` and `-0.0`, the sign of the result is unspecified. + #[inline] + #[must_use = "method returns a new vector and does not mutate the original value"] + pub fn min(self, other: Self) -> Self { + T::min(self, other) + } + + /// Returns the maximum of each lane. + /// + /// # Note + /// For `f32` and `f64`, if one of the values is `NAN`, then the other value is returned. + /// If the compared values are `0.0` and `-0.0`, the sign of the result is unspecified. + #[inline] + #[must_use = "method returns a new vector and does not mutate the original value"] + pub fn max(self, other: Self) -> Self { + T::max(self, other) + } + + /// Restrict each lane to a certain interval. + /// + /// For each lane in `self`, returns the corresponding lane in `max` if the lane is + /// greater than `max`, and the corresponding lane in `min` if the lane is less + /// than `min`. Otherwise returns the lane in `self`. + /// + /// # Note + /// For `f32` and `f64`, if any value is `NAN`, then the other value is returned. + #[inline] + #[must_use = "method returns a new vector and does not mutate the original value"] + pub fn clamp(self, min: Self, max: Self) -> Self { + T::clamp(self, min, max) + } + + /// Horizontal maximum. Returns the maximum lane in the vector. + /// + /// # Note + /// For `f32` and `f64`, only returns `NAN` if all lanes are `NAN`. + /// If the vector contains both `0.0` and `-0.0` and the result is 0, the sign of the result is + /// unspecified. + #[inline] + pub fn horizontal_max(self) -> T { + T::horizontal_max(self) + } + + /// Horizontal minimum. Returns the minimum lane in the vector. + /// + /// # Note + /// For `f32` and `f64`, only returns `NAN` if all lanes are `NAN`. + /// If the vector contains both `0.0` and `-0.0` and the result is 0, the sign of the result is + /// unspecified. + #[inline] + pub fn horizontal_min(self) -> T { + T::horizontal_min(self) } } diff --git a/crates/core_simd/src/mod.rs b/crates/core_simd/src/mod.rs index ec874a22389..21838dce12c 100644 --- a/crates/core_simd/src/mod.rs +++ b/crates/core_simd/src/mod.rs @@ -25,6 +25,7 @@ mod vendor; pub mod simd { pub(crate) use crate::core_simd::intrinsics; + pub use crate::core_simd::comparisons::{SimdPartialEq, SimdPartialOrd}; pub use crate::core_simd::lane_count::{LaneCount, SupportedLaneCount}; pub use crate::core_simd::masks::*; pub use crate::core_simd::select::Select; diff --git a/crates/core_simd/src/reduction.rs b/crates/core_simd/src/reduction.rs index e79a185816b..5500b412ac1 100644 --- a/crates/core_simd/src/reduction.rs +++ b/crates/core_simd/src/reduction.rs @@ -1,6 +1,6 @@ use crate::simd::intrinsics::{ - simd_reduce_add_ordered, simd_reduce_and, simd_reduce_max, simd_reduce_min, - simd_reduce_mul_ordered, simd_reduce_or, simd_reduce_xor, + simd_reduce_add_ordered, simd_reduce_and, simd_reduce_mul_ordered, simd_reduce_or, + simd_reduce_xor, }; use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount}; use core::ops::{BitAnd, BitOr, BitXor}; @@ -22,18 +22,6 @@ macro_rules! impl_integer_reductions { pub fn horizontal_product(self) -> $scalar { unsafe { simd_reduce_mul_ordered(self, 1) } } - - /// Horizontal maximum. Returns the maximum lane in the vector. - #[inline] - pub fn horizontal_max(self) -> $scalar { - unsafe { simd_reduce_max(self) } - } - - /// Horizontal minimum. Returns the minimum lane in the vector. - #[inline] - pub fn horizontal_min(self) -> $scalar { - unsafe { simd_reduce_min(self) } - } } } } @@ -77,24 +65,6 @@ macro_rules! impl_float_reductions { unsafe { simd_reduce_mul_ordered(self, 1.) } } } - - /// Horizontal maximum. Returns the maximum lane in the vector. - /// - /// Returns values based on equality, so a vector containing both `0.` and `-0.` may - /// return either. This function will not return `NaN` unless all lanes are `NaN`. - #[inline] - pub fn horizontal_max(self) -> $scalar { - unsafe { simd_reduce_max(self) } - } - - /// Horizontal minimum. Returns the minimum lane in the vector. - /// - /// Returns values based on equality, so a vector containing both `0.` and `-0.` may - /// return either. This function will not return `NaN` unless all lanes are `NaN`. - #[inline] - pub fn horizontal_min(self) -> $scalar { - unsafe { simd_reduce_min(self) } - } } } } diff --git a/crates/core_simd/src/vector/float.rs b/crates/core_simd/src/vector/float.rs index 4a4b23238c4..0a6275fb9af 100644 --- a/crates/core_simd/src/vector/float.rs +++ b/crates/core_simd/src/vector/float.rs @@ -157,50 +157,6 @@ macro_rules! impl_float_vector { let magnitude = self.to_bits() & !Self::splat(-0.).to_bits(); Self::from_bits(sign_bit | magnitude) } - - /// Returns the minimum of each lane. - /// - /// If one of the values is `NAN`, then the other value is returned. - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub fn min(self, other: Self) -> Self { - // TODO consider using an intrinsic - self.is_nan().select( - other, - self.lanes_ge(other).select(other, self) - ) - } - - /// Returns the maximum of each lane. - /// - /// If one of the values is `NAN`, then the other value is returned. - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub fn max(self, other: Self) -> Self { - // TODO consider using an intrinsic - self.is_nan().select( - other, - self.lanes_le(other).select(other, self) - ) - } - - /// Restrict each lane to a certain interval unless it is NaN. - /// - /// For each lane in `self`, returns the corresponding lane in `max` if the lane is - /// greater than `max`, and the corresponding lane in `min` if the lane is less - /// than `min`. Otherwise returns the lane in `self`. - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub fn clamp(self, min: Self, max: Self) -> Self { - assert!( - min.lanes_le(max).all(), - "each lane in `min` must be less than or equal to the corresponding lane in `max`", - ); - let mut x = self; - x = x.lanes_lt(min).select(min, x); - x = x.lanes_gt(max).select(max, x); - x - } } }; } diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs index 43ddde4c55e..7c9d67cf95f 100644 --- a/crates/core_simd/tests/ops_macros.rs +++ b/crates/core_simd/tests/ops_macros.rs @@ -163,6 +163,40 @@ macro_rules! impl_common_integer_tests { Ok(()) }); } + + fn min() { + test_helpers::test_binary_elementwise( + &Vector::::min, + &std::cmp::Ord::min, + &|_, _| true, + ) + } + + fn max() { + test_helpers::test_binary_elementwise( + &Vector::::max, + &std::cmp::Ord::max, + &|_, _| true, + ) + } + + fn clamp() { + test_helpers::test_3(&|value: [Scalar; LANES], mut min: [Scalar; LANES], mut max: [Scalar; LANES]| { + for (min, max) in min.iter_mut().zip(max.iter_mut()) { + if max < min { + core::mem::swap(min, max); + } + } + + let mut result_scalar = [Scalar::default(); LANES]; + for i in 0..LANES { + result_scalar[i] = value[i].clamp(min[i], max[i]); + } + let result_vector = Vector::from_array(value).clamp(min.into(), max.into()).to_array(); + test_helpers::prop_assert_biteq!(result_scalar, result_vector); + Ok(()) + }) + } } } }