Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add from_box_bytes and box_bytes_of with BoxBytes type #211

Merged
merged 5 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion src/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use alloc::{
vec,
vec::Vec,
};
use core::ops::{Deref, DerefMut};

/// As [`try_cast_box`](try_cast_box), but unwraps for you.
#[inline]
Expand Down Expand Up @@ -686,4 +687,117 @@ pub trait TransparentWrapperAlloc<Inner: ?Sized>:
}
}

impl<I: ?Sized, T: ?Sized + TransparentWrapper<I>> TransparentWrapperAlloc<I> for T {}
impl<I: ?Sized, T: ?Sized + TransparentWrapper<I>> TransparentWrapperAlloc<I>
for T
{
}

/// As `Box<[u8]>`, but remembers the original alignment.
pub struct BoxBytes {
// SAFETY: `ptr` is owned, was allocated with `layout`, and points to
// `layout.size()` initialized bytes.
ptr: NonNull<u8>,
layout: Layout,
}

impl Deref for BoxBytes {
type Target = [u8];

fn deref(&self) -> &Self::Target {
// SAFETY: See type invariant.
unsafe {
core::slice::from_raw_parts(self.ptr.as_ptr(), self.layout.size())
}
}
}

impl DerefMut for BoxBytes {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: See type invariant.
unsafe {
core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.layout.size())
}
}
}

impl Drop for BoxBytes {
fn drop(&mut self) {
// SAFETY: See type invariant.
unsafe { alloc::alloc::dealloc(self.ptr.as_ptr(), self.layout) };
}
}

impl<T: NoUninit> From<Box<T>> for BoxBytes {
fn from(value: Box<T>) -> Self {
let layout = Layout::new::<T>();
let ptr = Box::into_raw(value) as *mut u8;
// SAFETY: Box::into_raw() returns a non-null pointer.
let ptr = unsafe { NonNull::new_unchecked(ptr) };
BoxBytes { ptr, layout }
}
}

/// Re-interprets `Box<T>` as `BoxBytes`.
#[inline]
pub fn box_bytes_of<T: NoUninit>(input: Box<T>) -> BoxBytes {
input.into()
}

/// Re-interprets `BoxBytes` as `Box<T>`.
///
/// ## Panics
///
/// This is [`try_from_box_bytes`] but will panic on error and the input will be
/// dropped.
#[inline]
pub fn from_box_bytes<T: AnyBitPattern>(input: BoxBytes) -> Box<T> {
try_from_box_bytes(input).map_err(|(error, _)| error).unwrap()
}

/// Re-interprets `BoxBytes` as `Box<T>`.
///
/// ## Panics
///
/// * If the input isn't aligned for the new type
/// * If the input's length isn’t exactly the size of the new type
#[inline]
pub fn try_from_box_bytes<T: AnyBitPattern>(
input: BoxBytes,
) -> Result<Box<T>, (PodCastError, BoxBytes)> {
let layout = Layout::new::<T>();
if input.layout.align() != layout.align() {
return Err((PodCastError::AlignmentMismatch, input));
} else if input.layout.size() != layout.size() {
return Err((PodCastError::SizeMismatch, input));
} else {
let (ptr, _) = input.into_raw_parts();
// SAFETY: See type invariant.
Ok(unsafe { Box::from_raw(ptr.as_ptr() as *mut T) })
}
}

impl BoxBytes {
/// Constructs a `BoxBytes` from its raw parts.
///
/// # Safety
///
/// The pointer is owned, has been allocated with the provided layout, and
/// points to `layout.size()` initialized bytes.
pub unsafe fn from_raw_parts(ptr: NonNull<u8>, layout: Layout) -> Self {
BoxBytes { ptr, layout }
}

/// Deconstructs a `BoxBytes` into its raw parts.
///
/// The pointer is owned, has been allocated with the provided layout, and
/// points to `layout.size()` initialized bytes.
pub fn into_raw_parts(self) -> (NonNull<u8>, Layout) {
let me = ManuallyDrop::new(self);
(me.ptr, me.layout)
}

/// Returns the original layout.
pub fn layout(&self) -> Layout {
self.layout
}
}
60 changes: 60 additions & 0 deletions tests/std_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! depend on that can go here.

use bytemuck::*;
use core::num::NonZeroU8;

#[test]
fn test_transparent_vtabled() {
Expand Down Expand Up @@ -44,3 +45,62 @@ fn test_zero_sized_box_alloc() {
unsafe impl Zeroable for Empty {}
let _: Box<Empty> = try_zeroed_box().unwrap();
}

#[test]
#[cfg(feature = "extern_crate_alloc")]
fn test_try_from_box_bytes() {
// Different layout: target alignment is greater than source alignment.
assert_eq!(
try_from_box_bytes::<u32>(Box::new([0u8; 4]).into()).map_err(|(x, _)| x),
Err(PodCastError::AlignmentMismatch)
);

// Different layout: target alignment is less than source alignment.
assert_eq!(
try_from_box_bytes::<u32>(Box::new(0u64).into()).map_err(|(x, _)| x),
Err(PodCastError::AlignmentMismatch)
);

// Different layout: target size is greater than source size.
assert_eq!(
try_from_box_bytes::<[u32; 2]>(Box::new(0u32).into()).map_err(|(x, _)| x),
Err(PodCastError::SizeMismatch)
);

// Different layout: target size is less than source size.
assert_eq!(
try_from_box_bytes::<u32>(Box::new([0u32; 2]).into()).map_err(|(x, _)| x),
Err(PodCastError::SizeMismatch)
);

// Round trip: alignment is equal to size.
assert_eq!(*from_box_bytes::<u32>(Box::new(1000u32).into()), 1000u32);

// Round trip: alignment is divider of size.
assert_eq!(&*from_box_bytes::<[u8; 5]>(Box::new(*b"hello").into()), b"hello");

// It's ok for T to have uninitialized bytes.
#[cfg(feature = "derive")]
{
#[derive(Debug, Copy, Clone, PartialEq, Eq, AnyBitPattern)]
struct Foo(u8, u16);
assert_eq!(
*from_box_bytes::<Foo>(Box::new([0xc5c5u16; 2]).into()),
Foo(0xc5u8, 0xc5c5u16)
);
}
}

#[test]
#[cfg(feature = "extern_crate_alloc")]
fn test_box_bytes_of() {
assert_eq!(&*box_bytes_of(Box::new(*b"hello")), b"hello");

#[cfg(target_endian = "big")]
assert_eq!(&*box_bytes_of(Box::new(0x12345678)), b"\x12\x34\x56\x78");
#[cfg(target_endian = "little")]
assert_eq!(&*box_bytes_of(Box::new(0x12345678)), b"\x78\x56\x34\x12");

// It's ok for T to have invalid bit patterns.
assert_eq!(&*box_bytes_of(Box::new(NonZeroU8::new(0xc5))), b"\xc5");
}