Skip to content

Commit

Permalink
closes 49. hipextmallocwithflags (#107)
Browse files Browse the repository at this point in the history
* memory methods on struct. Add malloc_with_flags()
  • Loading branch information
smedegaard authored Nov 26, 2024
1 parent f57d39e commit c6f8cc3
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 42 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ semver = "1.0.23"
uuid = "1.11.0"
log = "0.4"
env_logger = "0.10"
bitflags = "2.6.0"

[build-dependencies]
# For build script
Expand Down
20 changes: 1 addition & 19 deletions src/runtime/memory.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,3 @@
use super::sys;
use crate::types::{Device, MemoryPointer, Result};

/// Allocates memory on a HIP device/accelerator.
///
/// This function allocates a block of `size` bytes of device memory and returns a
/// MemoryPointer that safely manages the memory allocation. The memory will be
/// automatically freed when the MemoryPointer is dropped.
///
/// If 0 is passed for `size`, `Ok(std::ptr::null_mut)` is returned.
///
/// # Arguments
/// * `size` - Size of memory allocation in bytes
///
/// # Returns
/// * `Ok(MemoryPointer)` - Handle to allocated device memory
/// * `Err(HipError)` - Error occurred during allocation
/// ```
pub fn malloc<T>(size: usize) -> Result<MemoryPointer<T>> {
MemoryPointer::new(size)
}
use crate::DeviceMallocFlag;
11 changes: 11 additions & 0 deletions src/types/flags.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use bitflags::bitflags;

bitflags::bitflags! {
pub struct DeviceMallocFlag: u32 {
const DEFAULT = 0x0;
const FINEGRAINED = 0x1;
const SIGNAL_MEMORY = 0x2;
const UNCACHED = 0x3;
const CONTIGUOUS = 0x4;
}
}
118 changes: 95 additions & 23 deletions src/types/memory.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,90 @@
use super::flags::DeviceMallocFlag;
use super::{HipError, HipResult, Result};
use crate::sys;

/// A wrapper for device memory allocated on the GPU.
/// Automatically frees the memory when dropped.
pub struct MemoryPointer<T> {
ptr: *mut T,
pointer: *mut T,
size: usize,
}

impl<T> MemoryPointer<T> {
pub fn new(size: usize) -> Result<Self> {
/// Private function that holds common logic for the
/// memory allocation functions.
///
/// Takes the size to allocate and
fn allocate_with_fn<F>(size: usize, alloc_fn: F) -> Result<Self>
where
F: FnOnce(*mut *mut std::ffi::c_void, usize) -> u32,
{
// Handle zero size allocation according to spec
if size == 0 {
return Ok(MemoryPointer {
ptr: std::ptr::null_mut(),
pointer: std::ptr::null_mut(),
size: 0,
});
}

let mut ptr = std::ptr::null_mut();

let code = unsafe {
sys::hipMalloc(
&mut ptr as *mut *mut T as *mut *mut std::ffi::c_void,
size * std::mem::size_of::<T>(),
)
};
let code = alloc_fn(
&mut ptr as *mut *mut T as *mut *mut std::ffi::c_void,
size * std::mem::size_of::<T>(),
);

let pointer = Self {
ptr: ptr as *mut T,
pointer: ptr as *mut T,
size,
};

(pointer, code).to_result()
}

/// Allocates memory on a HIP device/accelerator.
///
/// This function allocates a block of `size` bytes of device memory and returns a
/// MemoryPointer that safely manages the memory allocation. The memory will be
/// automatically freed when the MemoryPointer is dropped.
///
/// If 0 is passed for `size`, `Ok(std::ptr::null_mut)` is returned.
///
/// # Arguments
/// * `size` - Size of memory allocation in bytes
///
/// # Returns
/// * `Ok(MemoryPointer)` - Handle to allocated device memory
/// * `Err(HipError)` - Error occurred during allocation
/// ```
pub fn alloc(size: usize) -> Result<Self> {
Self::allocate_with_fn(size, |ptr, size| unsafe { sys::hipMalloc(ptr, size) })
}

/// Allocates memory on the default accelerator with specified allocation flags.
///
/// # Arguments
/// * `size` - The requested memory size in bytes
/// * `flag` - The memory allocation flag. Must be one of: DeviceMallocDefault,
/// DeviceMallocFinegrained, DeviceMallocUncached, or MallocSignalMemory
///
/// # Returns
/// * `Ok(MemoryPointer<T>)` - Successfully allocated memory pointer
/// * `Err(_)` - If allocation fails due to out of memory or invalid flags
///
/// # Notes
/// * If size is 0, returns null pointer with success status
/// * Invalid flags will result in hipErrorInvalidValue error
///
pub fn alloc_with_flag(size: usize, flag: DeviceMallocFlag) -> Result<Self> {
Self::allocate_with_fn(size, |ptr, size| unsafe {
sys::hipExtMallocWithFlags(ptr, size, flag.bits())
})
}

/// Returns the raw memory pointer.
///
/// Note: This pointer cannot be directly dereferenced from CPU code.
pub fn as_ptr(&self) -> *mut T {
self.ptr
self.pointer
}

/// Returns the size in bytes of the allocated memory
Expand All @@ -52,7 +97,7 @@ impl<T> MemoryPointer<T> {
impl<T> Drop for MemoryPointer<T> {
fn drop(&mut self) {
unsafe {
let code = sys::hipFree(self.ptr as *mut std::ffi::c_void);
let code = sys::hipFree(self.pointer as *mut std::ffi::c_void);
if code != 0 {
let error = HipError::new(code);
log::error!("MemoryPointer failed to free memory: {}", error);
Expand All @@ -69,37 +114,64 @@ mod tests {

#[test]
fn test_new_zero_size() {
let result = MemoryPointer::<u8>::new(0).unwrap();
assert!(result.ptr.is_null());
let result = MemoryPointer::<u8>::alloc(0).unwrap();
assert!(result.pointer.is_null());
assert_eq!(result.size, 0);
}

#[test]
fn test_new_valid_size() {
let size = 1024;
let result = MemoryPointer::<u8>::new(size).unwrap();
assert!(!result.ptr.is_null());
let result = MemoryPointer::<u8>::alloc(size).unwrap();
assert!(!result.pointer.is_null());
assert_eq!(result.size, size);
}

#[test]
fn test_new_different_types() {
// Test with different sized types
let result = MemoryPointer::<u32>::new(100).unwrap();
assert!(!result.ptr.is_null());
let result = MemoryPointer::<u32>::alloc(100).unwrap();
assert!(!result.pointer.is_null());

let result = MemoryPointer::<f64>::new(100).unwrap();
assert!(!result.ptr.is_null());
let result = MemoryPointer::<f64>::alloc(100).unwrap();
assert!(!result.pointer.is_null());
}

#[test]
fn test_large_allocation() {
let mb = 1024 * 1024;
let size = 3000 * mb;
println!("Attempting to allocate {} bytes", size);
let result = MemoryPointer::<u8>::new(size);
let result = MemoryPointer::<u8>::alloc(size);

sleep(Duration::from_secs(5));
assert!(!result.unwrap().ptr.is_null());
assert!(!result.unwrap().pointer.is_null());
}

#[test]
fn test_alloc_with_flag_success() {
let size = 1024;
let result = MemoryPointer::<u8>::alloc_with_flag(size, DeviceMallocFlag::DEFAULT);
assert!(result.is_ok());
let ptr = result.unwrap();
assert!(!ptr.pointer.is_null());
}

#[test]
fn test_alloc_with_flag_zero_size() {
let result = MemoryPointer::<u8>::alloc_with_flag(0, DeviceMallocFlag::DEFAULT);
assert!(result.is_ok());
let ptr = result.unwrap();
assert!(ptr.pointer.is_null());
}

#[test]
fn test_alloc_with_combined_flag() {
let size = 1024;
let flag = DeviceMallocFlag::DEFAULT | DeviceMallocFlag::FINEGRAINED;
let result = MemoryPointer::<u8>::alloc_with_flag(size, flag);
assert!(result.is_ok());
let ptr = result.unwrap();
assert!(!ptr.pointer.is_null());
}
}
2 changes: 2 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod device;
mod flags;
mod memory;
mod result;

pub use device::*;
pub use flags::*;
pub use memory::*;
pub use result::*;

0 comments on commit c6f8cc3

Please sign in to comment.