From 4e9b532d31ce8b4f10a08314682ab0ad5b1b1366 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:02:19 +0100 Subject: [PATCH] add Device::get_default_mem_pool() --- src/core/device.rs | 42 ++++++++++++++++++++++++++++++++++++++++-- src/core/memory.rs | 27 +++++++++++++++++++++++++++ src/core/result.rs | 2 ++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/core/device.rs b/src/core/device.rs index 0ef9adf..ddb162e 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -1,5 +1,5 @@ use super::sys; -use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, PCIBusId, Result}; +use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, MemPool, PCIBusId, Result}; use semver::Version; use std::ffi::CStr; use std::i32; @@ -150,13 +150,30 @@ impl Device { /// * There was an error retrieving the PCI bus ID pub fn get_device_pci_bus_id(&self) -> Result { let mut pci_bus_id = PCIBusId::new(); - unsafe { let code = sys::hipDeviceGetPCIBusId(pci_bus_id.as_mut_ptr(), pci_bus_id.len(), self.id); (pci_bus_id, code).to_result() } } + + /// Gets the default memory pool associated with this device. + /// + /// # Returns + /// * `Result` - The default memory pool for the device if successful + /// + /// # Errors + /// Returns `HipError` if: + /// * The device ID is invalid + /// * The operation is not supported on this device/platform + /// * There was an error retrieving the memory pool + pub fn get_default_mem_pool(&self) -> Result { + let mut mem_pool = std::ptr::null_mut(); + unsafe { + let code = sys::hipDeviceGetDefaultMemPool(&mut mem_pool, self.id); + (MemPool::from_raw(mem_pool), code).to_result() + } + } } /// Free Functions @@ -300,6 +317,27 @@ pub fn get_device_by_pci_bus_id(mut pci_bus_id: PCIBusId) -> Result { mod tests { use super::*; + #[test] + fn test_get_default_mem_pool() { + let device = Device::new(0); + let result = device.get_default_mem_pool(); + + // The operation might not be supported on all devices/platforms + match result { + Ok(mem_pool) => { + println!("Successfully retrieved default memory pool"); + assert!(!mem_pool.is_null()); + } + Err(e) => { + // Check if the error is "not supported" which is acceptable + if e.kind != HipErrorKind::NotSupported { + panic!("Unexpected error getting default memory pool: {:?}", e); + } + println!("Memory pools not supported on this device/platform"); + } + } + } + #[test] fn test_get_device_by_pci_bus_id() { let device = Device::new(0); diff --git a/src/core/memory.rs b/src/core/memory.rs index 09e09ed..3698fdc 100644 --- a/src/core/memory.rs +++ b/src/core/memory.rs @@ -210,6 +210,33 @@ impl Drop for MemoryPointer { } } +/// Represents a HIP memory pool handle +#[derive(Debug)] +pub struct MemPool { + handle: sys::hipMemPool_t, +} + +impl MemPool { + /// Creates a new MemPool from a raw handle + pub(crate) fn from_raw(handle: sys::hipMemPool_t) -> Self { + MemPool { handle } + } + + /// Returns true if the memory pool handle is null + pub fn is_null(&self) -> bool { + self.handle.is_null() + } + + /// Gets the raw handle to the memory pool + pub fn handle(&self) -> sys::hipMemPool_t { + self.handle + } +} + +// Implement Send and Sync since MemPool can be safely shared between threads +unsafe impl Send for MemPool {} +unsafe impl Sync for MemPool {} + #[repr(u32)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MemoryCopyKind { diff --git a/src/core/result.rs b/src/core/result.rs index d48904d..c93258c 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -24,6 +24,7 @@ pub enum HipErrorKind { Deinitialized = 4, InvalidDevice = 101, FileNotFound = 301, + NotSupported = 801, Unknown = 999, } @@ -37,6 +38,7 @@ impl HipErrorKind { 4 => HipErrorKind::Deinitialized, 101 => HipErrorKind::InvalidDevice, 301 => HipErrorKind::FileNotFound, + 801 => HipErrorKind::NotSupported, _ => HipErrorKind::Unknown, } }