Skip to content

Commit

Permalink
core: fix const ptr::swap_nonoverlapping when there are pointers at odd offsets in the type
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung authored and gitbot committed Feb 20, 2025
1 parent ce0e54d commit 4cf5da8
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 48 deletions.
14 changes: 7 additions & 7 deletions core/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3795,15 +3795,15 @@ where
/// See [`const_eval_select()`] for the rules and requirements around that intrinsic.
pub(crate) macro const_eval_select {
(
@capture { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
@capture$([$($binders:tt)*])? { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
if const
$(#[$compiletime_attr:meta])* $compiletime:block
else
$(#[$runtime_attr:meta])* $runtime:block
) => {
// Use the `noinline` arm, after adding explicit `inline` attributes
$crate::intrinsics::const_eval_select!(
@capture { $($arg : $ty = $val),* } $(-> $ret)? :
@capture$([$($binders)*])? { $($arg : $ty = $val),* } $(-> $ret)? :
#[noinline]
if const
#[inline] // prevent codegen on this function
Expand All @@ -3817,20 +3817,20 @@ pub(crate) macro const_eval_select {
},
// With a leading #[noinline], we don't add inline attributes
(
@capture { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
@capture$([$($binders:tt)*])? { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
#[noinline]
if const
$(#[$compiletime_attr:meta])* $compiletime:block
else
$(#[$runtime_attr:meta])* $runtime:block
) => {{
$(#[$runtime_attr])*
fn runtime($($arg: $ty),*) $( -> $ret )? {
fn runtime$(<$($binders)*>)?($($arg: $ty),*) $( -> $ret )? {
$runtime
}

$(#[$compiletime_attr])*
const fn compiletime($($arg: $ty),*) $( -> $ret )? {
const fn compiletime$(<$($binders)*>)?($($arg: $ty),*) $( -> $ret )? {
// Don't warn if one of the arguments is unused.
$(let _ = $arg;)*

Expand All @@ -3842,14 +3842,14 @@ pub(crate) macro const_eval_select {
// We support leaving away the `val` expressions for *all* arguments
// (but not for *some* arguments, that's too tricky).
(
@capture { $($arg:ident : $ty:ty),* $(,)? } $( -> $ret:ty )? :
@capture$([$($binders:tt)*])? { $($arg:ident : $ty:ty),* $(,)? } $( -> $ret:ty )? :
if const
$(#[$compiletime_attr:meta])* $compiletime:block
else
$(#[$runtime_attr:meta])* $runtime:block
) => {
$crate::intrinsics::const_eval_select!(
@capture { $($arg : $ty = $arg),* } $(-> $ret)? :
@capture$([$($binders)*])? { $($arg : $ty = $arg),* } $(-> $ret)? :
if const
$(#[$compiletime_attr])* $compiletime
else
Expand Down
73 changes: 42 additions & 31 deletions core/src/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@
#![allow(clippy::not_unsafe_ptr_arg_deref)]

use crate::cmp::Ordering;
use crate::intrinsics::const_eval_select;
use crate::marker::FnPtr;
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
use crate::{fmt, hash, intrinsics, ub_checks};
Expand Down Expand Up @@ -1074,25 +1075,6 @@ pub const unsafe fn swap<T>(x: *mut T, y: *mut T) {
#[rustc_const_unstable(feature = "const_swap_nonoverlapping", issue = "133668")]
#[rustc_diagnostic_item = "ptr_swap_nonoverlapping"]
pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
#[allow(unused)]
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
let x: *mut $ChunkTy = x.cast();
let y: *mut $ChunkTy = y.cast();
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
}
};
}

ub_checks::assert_unsafe_precondition!(
check_language_ub,
"ptr::swap_nonoverlapping requires that both pointer arguments are aligned and non-null \
Expand All @@ -1111,19 +1093,48 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
}
);

// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}
const_eval_select!(
@capture[T] { x: *mut T, y: *mut T, count: usize }:
if const {
// At compile-time we want to always copy this in chunks of `T`, to ensure that if there
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
// of a pointer (which would not work).
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
} else {
// Tries to perform the swap in units of `$ChunkTy` instead of `T`.
// The chunked path is taken only when `T` is at least as aligned as `$ChunkTy` and
// `size_of::<T>()` is an exact multiple of `size_of::<$ChunkTy>()`. In that case the
// macro `return`s from the enclosing function right after swapping. Otherwise the
// `if` is not taken and control falls through to whatever follows the invocation.
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
// Shadow the `T`-typed arguments with their chunk-typed equivalents.
let x: *mut $ChunkTy = x.cast();
let y: *mut $ChunkTy = y.cast();
// Each `T` is made up of `size_of::<T>() / size_of::<$ChunkTy>()` chunks,
// so scale the element count accordingly.
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
}
};
}

// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}

// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
}
)
}

/// Same behavior and safety conditions as [`swap_nonoverlapping`]
Expand Down
1 change: 1 addition & 0 deletions core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#![feature(const_black_box)]
#![feature(const_eval_select)]
#![feature(const_swap)]
#![feature(const_swap_nonoverlapping)]
#![feature(const_trait_impl)]
#![feature(core_intrinsics)]
#![feature(core_io_borrowed_buf)]
Expand Down
63 changes: 53 additions & 10 deletions core/tests/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,10 @@ fn swap_copy_untyped() {
}

#[test]
fn test_const_copy() {
fn test_const_copy_ptr() {
// `copy` and `copy_nonoverlapping` are thin layers on top of intrinsics. Ensure they correctly
// deal with pointers even when the pointers cross the boundary from one "element" being copied
// to another.
const {
let ptr1 = &1;
let mut ptr2 = &666;
Expand Down Expand Up @@ -899,21 +902,61 @@ fn test_const_copy() {
}

#[test]
fn test_const_swap() {
fn test_const_swap_ptr() {
// The `swap` functions are implemented in the library; they are not primitives.
// Only `swap_nonoverlapping` takes a count; pointers that cross multiple elements
// are *not* supported.
// We put the pointer at an odd offset in the type and copy the whole value as an array of
// bytes, which should catch most of the ways that the library implementation can get it wrong.

#[cfg(target_pointer_width = "32")]
type HalfPtr = i16;
#[cfg(target_pointer_width = "64")]
type HalfPtr = i32;

// Test payload: a reference placed between half-pointer-sized integer fields.
// `repr(C, packed)` keeps the fields in declaration order with no padding, so `ptr`
// starts directly after the half-pointer-sized `f1`.
#[repr(C, packed)]
#[allow(unused)]
struct S {
f1: HalfPtr,
// Crucially this field is at an offset that is not a multiple of the pointer size.
ptr: &'static i32,
// Make sure the entire type does not have a power-of-2 size:
// make it 3 pointers in size. This used to hit a bug in `swap_nonoverlapping`.
f2: [HalfPtr; 3],
}

// Wrapper that raises the alignment of the packed `S` to the pointer width of the
// target (4 bytes on 32-bit targets, 8 bytes on 64-bit targets).
// Ensure the entire thing is usize-aligned, so in principle this
// looks like it could be eligible for a `usize` copying loop.
#[cfg_attr(target_pointer_width = "32", repr(align(4)))]
#[cfg_attr(target_pointer_width = "64", repr(align(8)))]
struct A(S);

const {
let mut ptr1 = &1;
let mut ptr2 = &666;
let mut s1 = A(S { ptr: &1, f1: 0, f2: [0; 3] });
let mut s2 = A(S { ptr: &666, f1: 0, f2: [0; 3] });

// Swap ptr1 and ptr2, bytewise. `swap` does not take a count
// so the best we can do is use an array.
type T = [u8; mem::size_of::<&i32>()];
// Swap s1 and s2, as an array.
type T = [u8; mem::size_of::<A>()];
unsafe {
ptr::swap(ptr::from_mut(&mut ptr1).cast::<T>(), ptr::from_mut(&mut ptr2).cast::<T>());
ptr::swap(ptr::from_mut(&mut s1).cast::<T>(), ptr::from_mut(&mut s2).cast::<T>());
}

// Make sure they still work.
assert!(*ptr1 == 666);
assert!(*ptr2 == 1);
assert!(*s1.0.ptr == 666);
assert!(*s2.0.ptr == 1);

// Swap them back, again as an array.
unsafe {
ptr::swap_nonoverlapping(
ptr::from_mut(&mut s1).cast::<T>(),
ptr::from_mut(&mut s2).cast::<T>(),
1,
);
}

// Make sure they still work.
assert!(*s1.0.ptr == 1);
assert!(*s2.0.ptr == 666);
};
}

Expand Down

0 comments on commit 4cf5da8

Please sign in to comment.