
Commit db41351

Auto merge of rust-lang#98866 - nagisa:nagisa/align-offset-wroom, r=Mark-Simulacrum
Add a special case for align_offset /w stride != 1

This generalizes the previous `stride == 1` special case to apply to any situation where the requested alignment is divisible by the stride. This in turn allows the test case from rust-lang#98809 to produce ideal assembly, along the lines of:

    leaq 15(%rdi), %rax
    andq $-16, %rax

This also produces pretty high quality code for situations where the alignment of the input pointer isn’t known:

    pub unsafe fn ptr_u32(slice: *const u32) -> *const u32 {
        slice.offset(slice.align_offset(16) as isize)
    }

    // =>
    movl %edi, %eax
    andl $3, %eax
    leaq 15(%rdi), %rcx
    andq $-16, %rcx
    subq %rdi, %rcx
    shrq $2, %rcx
    negq %rax
    sbbq %rax, %rax
    orq %rcx, %rax
    leaq (%rdi,%rax,4), %rax

Here LLVM is smart enough to replace the `usize::MAX` special case with a branch-less bitwise-OR approach, where the mask is constructed using the `neg` and `sbb` instructions. This appears to work across various architectures I’ve tried.

This change ends up introducing more branches and code in situations where there is less knowledge of the arguments, for example when the requested alignment is entirely unknown. This use-case was never really a focus of this function, so I’m not particularly worried, especially since llvm-mca says the new code is still appreciably faster, despite all the new branching.

Fixes rust-lang#98809. Sadly, this does not help with rust-lang#72356.
2 parents d5e7f47 + 62a182c commit db41351
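The core of the new special case can be sketched as standalone Rust (names such as `align_offset_divisible` are made up for illustration; this is not the library code): when the requested alignment `a` is a power of two that is a multiple of the stride, the byte distance to the next `a`-aligned address is `round_up_to_next_alignment(p, a) - p`, converted into an element count.

    // Sketch only: assumes `a` is a power of two, `stride != 0`, and `a % stride == 0`,
    // mirroring the commit description rather than the real `ptr::align_offset`.
    fn align_offset_divisible(addr: usize, stride: usize, a: usize) -> usize {
        debug_assert!(a.is_power_of_two() && stride != 0 && a % stride == 0);
        // round_up_to_next_alignment(addr, a) - addr, written so LLVM can pick `lea`/`and`.
        let aligned = addr.wrapping_add(a - 1) & 0usize.wrapping_sub(a);
        let byte_offset = aligned.wrapping_sub(addr);
        if addr % stride == 0 {
            byte_offset / stride // offset in elements, which is what `align_offset` returns
        } else {
            usize::MAX // an address misaligned for its own stride can never reach `a`
        }
    }

For example, `align_offset_divisible(0x1003, 1, 16)` is 13, while `align_offset_divisible(0x1006, 4, 16)` is `usize::MAX`, because adding multiples of 4 to an address that is 2 modulo 4 never lands on a 16-byte boundary.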

3 files changed: +137 −57 lines changed


library/core/src/ptr/mod.rs

+53 −30
@@ -1557,11 +1557,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
     // FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
     // 1, where the method versions of these operations are not inlined.
     use intrinsics::{
-        unchecked_shl, unchecked_shr, unchecked_sub, wrapping_add, wrapping_mul, wrapping_sub,
+        cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
+        wrapping_add, wrapping_mul, wrapping_sub,
     };

-    let addr = p.addr();
-
     /// Calculate multiplicative modular inverse of `x` modulo `m`.
     ///
     /// This implementation is tailored for `align_offset` and has following preconditions:
@@ -1611,36 +1610,61 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
        }
    }

+    let addr = p.addr();
     let stride = mem::size_of::<T>();
     // SAFETY: `a` is a power-of-two, therefore non-zero.
     let a_minus_one = unsafe { unchecked_sub(a, 1) };
-    if stride == 1 {
-        // `stride == 1` case can be computed more simply through `-p (mod a)`, but doing so
-        // inhibits LLVM's ability to select instructions like `lea`. Instead we compute
+
+    if stride == 0 {
+        // SPECIAL_CASE: handle 0-sized types. No matter how many times we step, the address will
+        // stay the same, so no offset will be able to align the pointer unless it is already
+        // aligned. This branch _will_ be optimized out as `stride` is known at compile-time.
+        let p_mod_a = addr & a_minus_one;
+        return if p_mod_a == 0 { 0 } else { usize::MAX };
+    }
+
+    // SAFETY: `stride == 0` case has been handled by the special case above.
+    let a_mod_stride = unsafe { unchecked_rem(a, stride) };
+    if a_mod_stride == 0 {
+        // SPECIAL_CASE: In cases where the `a` is divisible by `stride`, byte offset to align a
+        // pointer can be computed more simply through `-p (mod a)`. In the off-chance the byte
+        // offset is not a multiple of `stride`, the input pointer was misaligned and no pointer
+        // offset will be able to produce a `p` aligned to the specified `a`.
         //
-        //    round_up_to_next_alignment(p, a) - p
+        // The naive `-p (mod a)` equation inhibits LLVM's ability to select instructions
+        // like `lea`. We compute `(round_up_to_next_alignment(p, a) - p)` instead. This
+        // redistributes operations around the load-bearing, but pessimizing `and` instruction
+        // sufficiently for LLVM to be able to utilize the various optimizations it knows about.
        //
-        // which distributes operations around the load-bearing, but pessimizing `and` sufficiently
-        // for LLVM to be able to utilize the various optimizations it knows about.
-        return wrapping_sub(wrapping_add(addr, a_minus_one) & wrapping_sub(0, a), addr);
-    }
+        // LLVM handles the branch here particularly nicely. If this branch needs to be evaluated
+        // at runtime, it will produce a mask `if addr_mod_stride == 0 { 0 } else { usize::MAX }`
+        // in a branch-free way and then bitwise-OR it with whatever result the `-p mod a`
+        // computation produces.
+
+        // SAFETY: `stride == 0` case has been handled by the special case above.
+        let addr_mod_stride = unsafe { unchecked_rem(addr, stride) };

-    let pmoda = addr & a_minus_one;
-    if pmoda == 0 {
-        // Already aligned. Yay!
-        return 0;
-    } else if stride == 0 {
-        // If the pointer is not aligned, and the element is zero-sized, then no amount of
-        // elements will ever align the pointer.
-        return usize::MAX;
+        return if addr_mod_stride == 0 {
+            let aligned_address = wrapping_add(addr, a_minus_one) & wrapping_sub(0, a);
+            let byte_offset = wrapping_sub(aligned_address, addr);
+            // SAFETY: `stride` is non-zero. This is guaranteed to divide exactly as well, because
+            // addr has been verified to be aligned to the original type’s alignment requirements.
+            unsafe { exact_div(byte_offset, stride) }
+        } else {
+            usize::MAX
+        };
    }

-    let smoda = stride & a_minus_one;
+    // GENERAL_CASE: From here on we’re handling the very general case where `addr` may be
+    // misaligned, there isn’t an obvious relationship between `stride` and `a` that we can take an
+    // advantage of, etc. This case produces machine code that isn’t particularly high quality,
+    // compared to the special cases above. The code produced here is still within the realm of
+    // miracles, given the situations this case has to deal with.
+
     // SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
-    let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
+    let gcdpow = unsafe { cttz_nonzero(stride).min(cttz_nonzero(a)) };
     // SAFETY: gcdpow has an upper-bound that’s at most the number of bits in a usize.
     let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
-
     // SAFETY: gcd is always greater or equal to 1.
     if addr & unsafe { unchecked_sub(gcd, 1) } == 0 {
         // This branch solves for the following linear congruence equation:
@@ -1656,14 +1680,13 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         // ` p' + s'o = 0 mod a' `
         // ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
         //
-        // The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the second
-        // term is "how does incrementing `p` by `s` bytes change the relative alignment of `p`" (again
-        // divided by `g`).
-        // Division by `g` is necessary to make the inverse well formed if `a` and `s` are not
-        // co-prime.
+        // The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the
+        // second term is "how does incrementing `p` by `s` bytes change the relative alignment of
+        // `p`" (again divided by `g`). Division by `g` is necessary to make the inverse well
+        // formed if `a` and `s` are not co-prime.
        //
        // Furthermore, the result produced by this solution is not "minimal", so it is necessary
-        // to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
+        // to take the result `o mod lcm(s, a)`. This `lcm(s, a)` is the same as `a'`.

        // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
        // `a`.
@@ -1673,11 +1696,11 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         let a2minus1 = unsafe { unchecked_sub(a2, 1) };
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`.
-        let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
+        let s2 = unsafe { unchecked_shr(stride & a_minus_one, gcdpow) };
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
         // always be strictly greater than `(p % a) >> gcdpow`.
-        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
+        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(addr & a_minus_one, gcdpow)) };
         // SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
         // because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
         return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;

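The GENERAL_CASE comments above compress a fair amount of modular arithmetic. A standalone sketch, with hypothetical names (`general_case`, `naive_align_offset`) and a brute-force `mod_inv` that only needs to handle these tiny demo inputs (the library uses a small lookup-table-seeded `mod_inv` instead), mirrors the `gcdpow`/`a2`/`s2`/`minusp2` computation and checks it against an exhaustive search:

    // Sketch only: models the general-case congruence solving, not the intrinsics-based code.
    fn mod_inv(x: usize, m: usize) -> usize {
        // Brute force is enough for a demonstration; panics if no inverse exists.
        (1..m).find(|&i| (x * i) % m == 1).expect("x must be invertible modulo m")
    }

    fn general_case(p: usize, s: usize, a: usize) -> usize {
        // `a` is a power of two, so gcd(s, a) is the largest power of two dividing both.
        let gcdpow = s.trailing_zeros().min(a.trailing_zeros());
        let g = 1usize << gcdpow;
        if p & (g - 1) != 0 {
            return usize::MAX; // `p` can never become divisible by `g`, hence never by `a`
        }
        let a2 = a >> gcdpow; // a'
        let s2 = (s & (a - 1)) >> gcdpow; // s' = (s mod a) / g
        let minusp2 = a2 - ((p & (a - 1)) >> gcdpow); // a' - (p' mod a')
        minusp2.wrapping_mul(mod_inv(s2, a2)) & (a2 - 1) // o mod a'
    }

    fn naive_align_offset(p: usize, s: usize, a: usize) -> usize {
        (0..a).find(|&o| (p + s * o) % a == 0).unwrap_or(usize::MAX)
    }

    fn main() {
        // stride 6, alignment 8, address 2: g = 2, a' = 4, s' = 3, p' = 1, so
        // o = (4 - 1) * inv(3, 4) mod 4 = 3 * 3 mod 4 = 1, and 2 + 6 * 1 = 8 is aligned.
        assert_eq!(general_case(2, 6, 8), 1);
        assert_eq!(naive_align_offset(2, 6, 8), 1);
        // Address 1 is odd while gcd(6, 8) = 2, so no offset can ever align it.
        assert_eq!(general_case(1, 6, 8), usize::MAX);
        assert_eq!(naive_align_offset(1, 6, 8), usize::MAX);
    }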
library/core/tests/ptr.rs

+36 −27
@@ -359,7 +359,7 @@ fn align_offset_zst() {
 }

 #[test]
-fn align_offset_stride1() {
+fn align_offset_stride_one() {
     // For pointers of stride = 1, the pointer can always be aligned. The offset is equal to
     // number of bytes.
     let mut align = 1;
@@ -380,24 +380,8 @@ fn align_offset_stride1() {
 }

 #[test]
-fn align_offset_weird_strides() {
-    #[repr(packed)]
-    struct A3(u16, u8);
-    struct A4(u32);
-    #[repr(packed)]
-    struct A5(u32, u8);
-    #[repr(packed)]
-    struct A6(u32, u16);
-    #[repr(packed)]
-    struct A7(u32, u16, u8);
-    #[repr(packed)]
-    struct A8(u32, u32);
-    #[repr(packed)]
-    struct A9(u32, u32, u8);
-    #[repr(packed)]
-    struct A10(u32, u32, u16);
-
-    unsafe fn test_weird_stride<T>(ptr: *const T, align: usize) -> bool {
+fn align_offset_various_strides() {
+    unsafe fn test_stride<T>(ptr: *const T, align: usize) -> bool {
         let numptr = ptr as usize;
         let mut expected = usize::MAX;
         // Naive but definitely correct way to find the *first* aligned element of stride::<T>.
@@ -431,14 +415,39 @@ fn align_offset_weird_strides() {
     while align < limit {
         for ptr in 1usize..4 * align {
             unsafe {
-                x |= test_weird_stride::<A3>(ptr::invalid::<A3>(ptr), align);
-                x |= test_weird_stride::<A4>(ptr::invalid::<A4>(ptr), align);
-                x |= test_weird_stride::<A5>(ptr::invalid::<A5>(ptr), align);
-                x |= test_weird_stride::<A6>(ptr::invalid::<A6>(ptr), align);
-                x |= test_weird_stride::<A7>(ptr::invalid::<A7>(ptr), align);
-                x |= test_weird_stride::<A8>(ptr::invalid::<A8>(ptr), align);
-                x |= test_weird_stride::<A9>(ptr::invalid::<A9>(ptr), align);
-                x |= test_weird_stride::<A10>(ptr::invalid::<A10>(ptr), align);
+                #[repr(packed)]
+                struct A3(u16, u8);
+                x |= test_stride::<A3>(ptr::invalid::<A3>(ptr), align);
+
+                struct A4(u32);
+                x |= test_stride::<A4>(ptr::invalid::<A4>(ptr), align);
+
+                #[repr(packed)]
+                struct A5(u32, u8);
+                x |= test_stride::<A5>(ptr::invalid::<A5>(ptr), align);
+
+                #[repr(packed)]
+                struct A6(u32, u16);
+                x |= test_stride::<A6>(ptr::invalid::<A6>(ptr), align);
+
+                #[repr(packed)]
+                struct A7(u32, u16, u8);
+                x |= test_stride::<A7>(ptr::invalid::<A7>(ptr), align);
+
+                #[repr(packed)]
+                struct A8(u32, u32);
+                x |= test_stride::<A8>(ptr::invalid::<A8>(ptr), align);
+
+                #[repr(packed)]
+                struct A9(u32, u32, u8);
+                x |= test_stride::<A9>(ptr::invalid::<A9>(ptr), align);
+
+                #[repr(packed)]
+                struct A10(u32, u32, u16);
+                x |= test_stride::<A10>(ptr::invalid::<A10>(ptr), align);
+
+                x |= test_stride::<u32>(ptr::invalid::<u32>(ptr), align);
+                x |= test_stride::<u128>(ptr::invalid::<u128>(ptr), align);
             }
         }
         align = (align + 1).next_power_of_two();

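As a concrete illustration of what the renamed `align_offset_various_strides` test checks, here is a hypothetical standalone assertion (not part of the test file) using the same packed layout as its `A5`: stride 5 and alignment 8 are co-prime, so some element offset always reaches an aligned address.

    // Sketch only: mirrors one case the test exercises, outside the test harness.
    #[repr(packed)]
    struct A5(u32, u8); // size_of::<A5>() == 5

    fn main() {
        let ptr = 1usize as *const A5; // a deliberately misaligned (and dangling) address
        let o = ptr.align_offset(8);
        assert_ne!(o, usize::MAX);
        // With the current implementation `o` is 3, and 1 + 5 * 3 == 16 is 8-byte aligned.
        assert_eq!((1 + 5 * o) % 8, 0);
    }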
src/test/assembly/align_offset.rs

+48
@@ -0,0 +1,48 @@
+// assembly-output: emit-asm
+// compile-flags: -Copt-level=1
+// only-x86_64
+// min-llvm-version: 14.0
+#![crate_type="rlib"]
+
+// CHECK-LABEL: align_offset_byte_ptr
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+#[no_mangle]
+pub fn align_offset_byte_ptr(ptr: *const u8) -> usize {
+    ptr.align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_byte_slice
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+#[no_mangle]
+pub fn align_offset_byte_slice(slice: &[u8]) -> usize {
+    slice.as_ptr().align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_word_ptr
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+// CHECK: shrq
+// This `ptr` is not known to be aligned, so it is required to check if it is at all possible to
+// align. LLVM applies a simple mask.
+// CHECK: orq
+#[no_mangle]
+pub fn align_offset_word_ptr(ptr: *const u32) -> usize {
+    ptr.align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_word_slice
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+// CHECK: shrq
+// `slice` is known to be aligned, so `!0` is not possible as a return
+// CHECK-NOT: orq
+#[no_mangle]
+pub fn align_offset_word_slice(slice: &[u32]) -> usize {
+    slice.as_ptr().align_offset(32)
+}
