Skip to content

Commit

Permalink
Add unchecked_disjoint_bitor with fallback intrinsic implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
scottmcm authored and gitbot committed Feb 20, 2025
1 parent e9cbf64 commit 52234f8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 3 deletions.
40 changes: 40 additions & 0 deletions core/src/intrinsics/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,43 @@ impl const CarryingMulAdd for i128 {
(low, high)
}
}

#[const_trait]
#[rustc_const_unstable(feature = "core_intrinsics_fallbacks", issue = "none")]
pub trait DisjointBitOr: Copy + 'static {
/// This is always just `assume((self & other) == 0); self | other`.
///
/// It's essential that the assume is there so that this is sufficient to
/// specify the UB for MIRI, rather than it needing to re-implement it.
///
/// # Safety
/// See [`super::disjoint_bitor`].
unsafe fn disjoint_bitor(self, other: Self) -> Self;
}
macro_rules! zero {
(bool) => {
false
};
($t:ident) => {
0
};
}
macro_rules! impl_disjoint_bitor {
($($t:ident,)+) => {$(
#[rustc_const_unstable(feature = "core_intrinsics_fallbacks", issue = "none")]
impl const DisjointBitOr for $t {
#[inline]
unsafe fn disjoint_bitor(self, other: Self) -> Self {
// SAFETY: our precondition is that there are no bits in common,
// so this is just telling that to the backend.
unsafe { super::assume((self & other) == zero!($t)) };
self | other
}
}
)+};
}
impl_disjoint_bitor! {
bool,
u8, u16, u32, u64, u128, usize,
i8, i16, i32, i64, i128, isize,
}
19 changes: 19 additions & 0 deletions core/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3248,6 +3248,25 @@ pub const fn three_way_compare<T: Copy>(_lhs: T, _rhss: T) -> crate::cmp::Orderi
unimplemented!()
}

/// Combine two values which have no bits in common.
///
/// This allows the backend to implement it as `a + b` *or* `a | b`,
/// depending which us easier to implement on a specific target.
///
/// # Safety
///
/// Requires that `(a & b) == 0`, or equivalently that `(a | b) == (a + b)`.
///
/// Otherwise it's immediate UB.
#[rustc_const_unstable(feature = "disjoint_bitor", issue = "135758")]
#[rustc_nounwind]
#[cfg_attr(not(bootstrap), rustc_intrinsic)]
#[miri::intrinsic_fallback_is_spec] // the fallbacks all `assume` to tell MIRI
pub const unsafe fn disjoint_bitor<T: ~const fallback::DisjointBitOr>(a: T, b: T) -> T {
// SAFETY: same preconditions as this function.
unsafe { fallback::DisjointBitOr::disjoint_bitor(a, b) }
}

/// Performs checked integer addition.
///
/// Note that, unlike most intrinsics, this is safe to call;
Expand Down
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
#![feature(const_eval_select)]
#![feature(core_intrinsics)]
#![feature(coverage_attribute)]
#![feature(disjoint_bitor)]
#![feature(internal_impls_macro)]
#![feature(ip)]
#![feature(is_ascii_octdigit)]
Expand Down
59 changes: 56 additions & 3 deletions core/src/num/uint_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,52 @@ macro_rules! uint_impl {
self % rhs
}

/// Same value as
#[doc = concat!("`<", stringify!($SelfT), " as BitOr>::bitor(self, other)`")]
/// but UB if any bit position is set in both inputs.
///
/// This is a situational μoptimization for places where you'd rather use
/// addition on some platforms and bitwise or on other platforms, based on
/// exactly which instructions combine better with whatever else you're
/// doing. Note that there's no reason to bother using this for places
/// where it's clear from the operations involved that they can't overlap.
/// For example, if you're combining `u16`s into a `u32` with
/// `((a as u32) << 16) | (b as u32)`, that's fine, as the backend will
/// know those sides of the `|` are disjoint without needing help.
///
/// # Examples
///
/// ```
/// #![feature(disjoint_bitor)]
///
/// // SAFETY: `1` and `4` have no bits in common.
/// unsafe {
#[doc = concat!(" assert_eq!(1_", stringify!($SelfT), ".unchecked_disjoint_bitor(4), 5);")]
/// }
/// ```
///
/// # Safety
///
/// Requires that `(self | other) == 0`, otherwise it's immediate UB.
///
/// Equivalently, requires that `(self | other) == (self + other)`.
#[unstable(feature = "disjoint_bitor", issue = "135758")]
#[rustc_const_unstable(feature = "disjoint_bitor", issue = "135758")]
#[inline]
pub const unsafe fn unchecked_disjoint_bitor(self, other: Self) -> Self {
assert_unsafe_precondition!(
check_language_ub,
concat!(stringify!($SelfT), "::unchecked_disjoint_bitor cannot have overlapping bits"),
(
lhs: $SelfT = self,
rhs: $SelfT = other,
) => (lhs & rhs) == 0,
);

// SAFETY: Same precondition
unsafe { intrinsics::disjoint_bitor(self, other) }
}

/// Returns the logarithm of the number with respect to an arbitrary base,
/// rounded down.
///
Expand Down Expand Up @@ -2346,15 +2392,22 @@ macro_rules! uint_impl {
/// assert_eq!((sum1, sum0), (9, 6));
/// ```
#[unstable(feature = "bigint_helper_methods", issue = "85532")]
#[rustc_const_unstable(feature = "bigint_helper_methods", issue = "85532")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn carrying_add(self, rhs: Self, carry: bool) -> (Self, bool) {
// note: longer-term this should be done via an intrinsic, but this has been shown
// to generate optimal code for now, and LLVM doesn't have an equivalent intrinsic
let (a, b) = self.overflowing_add(rhs);
let (c, d) = a.overflowing_add(carry as $SelfT);
(c, b | d)
let (a, c1) = self.overflowing_add(rhs);
let (b, c2) = a.overflowing_add(carry as $SelfT);
// Ideally LLVM would know this is disjoint without us telling them,
// but it doesn't <https://github.com/llvm/llvm-project/issues/118162>
// SAFETY: Only one of `c1` and `c2` can be set.
// For c1 to be set we need to have overflowed, but if we did then
// `a` is at most `MAX-1`, which means that `c2` cannot possibly
// overflow because it's adding at most `1` (since it came from `bool`)
(b, unsafe { intrinsics::disjoint_bitor(c1, c2) })
}

/// Calculates `self` + `rhs` with a signed `rhs`.
Expand Down

0 comments on commit 52234f8

Please sign in to comment.