diff --git a/src/core/device.rs b/src/core/device.rs index 0264a21..0ef9adf 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -1,179 +1,201 @@ use super::sys; -use super::{Device, DeviceP2PAttribute, HipError, HipErrorKind, HipResult, PCIBusId, Result}; +use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, PCIBusId, Result}; use semver::Version; use std::ffi::CStr; use std::i32; use uuid::Uuid; -/// Get the number of available HIP devices. -/// -/// # Returns -/// * `Result` - The number of devices if successful -/// -/// # Errors -/// Returns `HipError` if: -/// * The runtime is not initialized (`HipErrorKind::NotInitialized`) -/// * The operation fails for other reasons -pub fn get_device_count() -> Result { - unsafe { - let mut count = 0; - let code = sys::hipGetDeviceCount(&mut count); - (count, code).to_result() - } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Device { + pub(crate) id: i32, } -/// Gets the currently active HIP device. -/// -/// # Returns -/// Returns a `Result` containing either: -/// * `Ok(Device)` - The currently active device [`crate::Device`] if one is set -/// * `Err(HipError)` - If getting the device failed -/// -/// # Errors -/// Returns `HipError` if: -/// * No device is currently active -/// * HIP runtime is not initialized -/// * There was an error accessing device information -pub fn get_device() -> Result { - unsafe { - let mut device_id: i32 = -1; - let code = sys::hipGetDevice(&mut device_id); - (Device::new(device_id), code).to_result() +impl Device { + /// Creates a new Device handle representing a HIP device. + /// + /// # Arguments + /// * `id` - The device ID to associate with this Device instance + /// + /// # Returns + /// A new `Device` instance initialized with the provided ID + pub fn new(id: i32) -> Self { + Device { id } } -} -/// Sets the active HIP device for the current host thread. -/// -/// This function makes the specified device active for all subsequent HIP operations -/// in the current host thread. Other host threads are not affected. -/// -/// # Arguments -/// * `device` - The device [`crate::Device`] to make active -/// -/// # Returns -/// * `Ok(())` if the device was successfully made active -/// * `Err(HipError)` if the operation failed -/// -/// # Errors -/// Returns `HipError` if: -/// * The device ID is invalid (greater than or equal to device count) -/// * The HIP runtime is not initialized -/// * The specified device has encountered a previous error and is in a broken state -pub fn set_device(device: Device) -> Result { - unsafe { - let code = sys::hipSetDevice(device.id); - (device, code).to_result() + /// Returns the raw HIP device ID. + /// + /// Gets the 'ordinal' numeric identifier that identifies this HIP device. + /// The ID is assigned by the HIP runtime and matches the index when enumerating devices. + /// + /// # Returns + /// * `i32` - The device ID number + pub fn id(&self) -> i32 { + self.id } -} -/// Gets the compute capability version of a HIP device. -/// -/// This function retrieves the major and minor version numbers that specify the compute capability -/// of the given HIP device. Compute capability indicates the technical specifications and features -/// supported by the device's architecture. -/// -/// # Arguments -/// * `device` - The device [`crate::Device`] to query instance representing the HIP device to query -/// -/// # Returns -/// * `Result` - On success, returns a `Version` struct containing the major and minor version -/// numbers of the device's compute capability. On failure, returns an error indicating what went wrong. -pub fn device_compute_capability(device: Device) -> Result { - unsafe { - let mut major: i32 = -1; - let mut minor: i32 = -1; - let code = sys::hipDeviceComputeCapability(&mut major, &mut minor, device.id); - let version = Version::new(major as u64, minor as u64, 0); - (version, code).to_result() + /// Gets the compute capability version of the HIP device. + /// + /// This function retrieves the major and minor version numbers that specify the compute capability + /// of the given HIP device. Compute capability indicates the technical specifications and features + /// supported by the device's architecture. + /// + /// # Returns + /// * `Result` - On success, returns a `Version` struct containing the major and minor version + /// numbers of the device's compute capability. On failure, returns an error indicating what went wrong. + pub fn device_compute_capability(&self) -> Result { + unsafe { + let mut major: i32 = -1; + let mut minor: i32 = -1; + let code = sys::hipDeviceComputeCapability(&mut major, &mut minor, self.id); + let version = Version::new(major as u64, minor as u64, 0); + (version, code).to_result() + } } -} -/// Returns the total amount of memory on a HIP device. -/// -/// # Arguments -/// * `device` - The device [`crate::Device`] to query -/// -/// # Returns -/// * `Result` - The total memory in bytes if successful -/// -/// # Errors -/// Returns `HipError` if: -/// * The device is invalid -/// * The runtime is not initialized -pub fn device_total_mem(device: Device) -> Result { - unsafe { - let mut size: usize = 0; - let code = sys::hipDeviceTotalMem(&mut size, device.id); - (size, code).to_result() + /// Returns the total amount of memory on the device. + /// + /// # Returns + /// * `Result` - The total memory in bytes if successful + /// + /// # Errors + /// Returns `HipError` if: + /// * The device is invalid + /// * The runtime is not initialized + pub fn device_total_mem(&self) -> Result { + unsafe { + let mut size: usize = 0; + let code = sys::hipDeviceTotalMem(&mut size, self.id); + (size, code).to_result() + } } -} -/// Gets the name of a HIP device. -/// -/// # Arguments -/// * `device` - The device [`crate::Device`] to query -/// -/// # Returns -/// * `Result` - The device name if successful -/// -/// # Errors -/// Returns `HipError` if: -/// * The device ID is invalid -/// * There was an error retrieving the device name -/// * The name string could not be converted to valid UTF-8 -pub fn get_device_name(device: Device) -> Result { - const buffer_size: usize = 64; - let mut buffer = vec![0i8; buffer_size]; + /// Gets the name of the device. + /// + /// + /// # Returns + /// * `Result` - The device name if successful + /// + /// # Errors + /// Returns `HipError` if: + /// * The device ID is invalid + /// * There was an error retrieving the device name + /// * The name string could not be converted to valid UTF-8 + pub fn get_device_name(&self) -> Result { + const buffer_size: usize = 64; + let mut buffer = vec![0i8; buffer_size]; + + unsafe { + let code = sys::hipDeviceGetName(buffer.as_mut_ptr(), buffer.len() as i32, self.id); + // Convert the C string to a Rust String + let c_str = CStr::from_ptr(buffer.as_ptr()); + (c_str.to_string_lossy().into_owned(), code).to_result() + } + } - unsafe { - let code = sys::hipDeviceGetName(buffer.as_mut_ptr(), buffer.len() as i32, device.id); - // Convert the C string to a Rust String - let c_str = CStr::from_ptr(buffer.as_ptr()); - (c_str.to_string_lossy().into_owned(), code).to_result() + /// Gets the UUID bytes for a HIP device. + /// + /// # Arguments + /// * `device` - The device [`crate::Device`] to query + /// + /// # Returns + /// * `Result<[i8; 16]>` - The UUID as a 16-byte array if successful + /// + /// # Errors + /// Returns `HipError` if: + /// * The device is invalid + /// * The runtime is not initialized + /// * There was an error retrieving the UUID + fn get_device_uuid_bytes(&self) -> Result<[i8; 16]> { + let mut hip_bytes = sys::hipUUID_t { bytes: [0; 16] }; + unsafe { + let code = sys::hipDeviceGetUuid(&mut hip_bytes, self.id); + (hip_bytes.bytes, code).to_result() + } + } + + /// Gets the UUID for a HIP device. + /// + /// Retrieves the unique identifier (UUID) for a specified HIP device, + /// + /// # Arguments + /// * `device` - The device [`crate::Device`] to query + /// + /// # Returns + /// * `Result` - The device UUID if successful + /// + /// # Errors + /// Returns `HipError` if: + /// * The device is invalid + /// * The runtime is not initialized + /// * There was an error retrieving the UUID + pub fn get_device_uuid(&self) -> Result { + Self::get_device_uuid_bytes(self).map(|bytes| { + let uuid_bytes: [u8; 16] = bytes.map(|b| b as u8); + Uuid::from_bytes(uuid_bytes) + }) + } + + /// Gets the PCI bus ID string for a HIP device. + /// + /// # Arguments + /// * `device` - The device [`crate::Device`] to query + /// + /// # Returns + /// * `Result` - The PCI bus ID string if successful + /// + /// # Errors + /// Returns `HipError` if: + /// * The device is invalid + /// * The runtime is not initialized + /// * There was an error retrieving the PCI bus ID + pub fn get_device_pci_bus_id(&self) -> Result { + let mut pci_bus_id = PCIBusId::new(); + + unsafe { + let code = + sys::hipDeviceGetPCIBusId(pci_bus_id.as_mut_ptr(), pci_bus_id.len(), self.id); + (pci_bus_id, code).to_result() + } } } -/// Gets the UUID bytes for a HIP device. +/// Free Functions + +// Synchronizes the current device by waiting for all active streams to complete. /// -/// # Arguments -/// * `device` - The device [`crate::Device`] to query +/// This function blocks the host thread until all commands in all streams on the +/// current device have completed. This is a global synchronization point. /// /// # Returns -/// * `Result<[i8; 16]>` - The UUID as a 16-byte array if successful +/// * `Ok(())` if synchronization was successful +/// * `Err(HipError)` if the operation failed /// /// # Errors /// Returns `HipError` if: -/// * The device is invalid -/// * The runtime is not initialized -/// * There was an error retrieving the UUID -fn get_device_uuid_bytes(device: Device) -> Result<[i8; 16]> { - let mut hip_bytes = sys::hipUUID_t { bytes: [0; 16] }; +/// * No device is currently active +/// * The HIP runtime is not initialized +pub fn synchronize() -> Result<()> { unsafe { - let code = sys::hipDeviceGetUuid(&mut hip_bytes, device.id); - (hip_bytes.bytes, code).to_result() + let code = sys::hipDeviceSynchronize(); + ((), code).to_result() } } -/// Gets the UUID for a HIP device. -/// -/// Retrieves the unique identifier (UUID) for a specified HIP device, -/// -/// # Arguments -/// * `device` - The device [`crate::Device`] to query +/// Get the number of available HIP devices. /// /// # Returns -/// * `Result` - The device UUID if successful +/// * `Result` - The number of devices if successful /// /// # Errors /// Returns `HipError` if: -/// * The device is invalid -/// * The runtime is not initialized -/// * There was an error retrieving the UUID -pub fn get_device_uuid(device: Device) -> Result { - get_device_uuid_bytes(device).map(|bytes| { - let uuid_bytes: [u8; 16] = bytes.map(|b| b as u8); - Uuid::from_bytes(uuid_bytes) - }) +/// * The runtime is not initialized (`HipErrorKind::NotInitialized`) +/// * The operation fails for other reasons +pub fn get_device_count() -> Result { + unsafe { + let mut count = 0; + let code = sys::hipGetDeviceCount(&mut count); + (count, code).to_result() + } } /// Retrieves a peer-to-peer attribute value between two HIP devices. @@ -209,25 +231,47 @@ pub fn get_device_p2p_attribute( } } -/// Gets the PCI bus ID string for a HIP device. +/// Gets the currently active HIP device. +/// +/// # Returns +/// Returns a `Result` containing either: +/// * `Ok(Device)` - The currently active device [`crate::Device`] if one is set +/// * `Err(HipError)` - If getting the device failed +/// +/// # Errors +/// Returns `HipError` if: +/// * No device is currently active +/// * HIP runtime is not initialized +/// * There was an error accessing device information +pub fn get_device() -> Result { + unsafe { + let mut device_id: i32 = -1; + let code = sys::hipGetDevice(&mut device_id); + (Device::new(device_id), code).to_result() + } +} + +/// Sets the active HIP device for the current host thread. +/// +/// This function makes the specified device active for all subsequent HIP operations +/// in the current host thread. Other host threads are not affected. /// /// # Arguments -/// * `device` - The device [`crate::Device`] to query +/// * `device` - The device [`crate::Device`] to make active /// /// # Returns -/// * `Result` - The PCI bus ID string if successful +/// * `Ok(())` if the device was successfully made active +/// * `Err(HipError)` if the operation failed /// /// # Errors /// Returns `HipError` if: -/// * The device is invalid -/// * The runtime is not initialized -/// * There was an error retrieving the PCI bus ID -pub fn get_device_pci_bus_id(device: Device) -> Result { - let mut pci_bus_id = PCIBusId::new(); - +/// * The device ID is invalid (greater than or equal to device count) +/// * The HIP runtime is not initialized +/// * The specified device has encountered a previous error and is in a broken state +pub fn set_device(device: Device) -> Result { unsafe { - let code = sys::hipDeviceGetPCIBusId(pci_bus_id.as_mut_ptr(), pci_bus_id.len(), device.id); - (pci_bus_id, code).to_result() + let code = sys::hipSetDevice(device.id); + (device, code).to_result() } } @@ -258,13 +302,9 @@ mod tests { #[test] fn test_get_device_by_pci_bus_id() { - // we are relying on `get_device_pci_bus_id()` working as intended to test this function. - // TODO: consider mocking to avoid test dependencies - // First get a valid PCI bus ID from an existing device let device = Device::new(0); - let pci_id = get_device_pci_bus_id(device).unwrap(); + let pci_id = device.get_device_pci_bus_id().unwrap(); - // Test getting device by that PCI ID let result = get_device_by_pci_bus_id(pci_id); assert!(result.is_ok()); assert_eq!(result.unwrap().id(), device.id()); @@ -272,7 +312,7 @@ mod tests { #[test] fn test_get_device_by_invalid_pci_bus_id() { - let invalid_pci_id = PCIBusId::new(); // invalid PCI ID, only contains `0`'s + let invalid_pci_id = PCIBusId::new(); let result = get_device_by_pci_bus_id(invalid_pci_id); assert!(result.is_err()); } @@ -280,7 +320,7 @@ mod tests { #[test] fn test_get_device_pci_bus_id() { let device = Device::new(0); - let result = get_device_pci_bus_id(device); + let result = device.get_device_pci_bus_id(); assert!(result.is_ok()); let pci_id = result.unwrap(); println!("Device PCI Bus ID: {:?}", pci_id); @@ -289,65 +329,7 @@ mod tests { #[test] fn test_get_device_pci_bus_id_invalid_device() { let invalid_device = Device::new(99); - let result = get_device_pci_bus_id(invalid_device); - assert!(result.is_err()); - assert_eq!(result.unwrap_err().kind, HipErrorKind::InvalidDevice); - } - - #[test] - fn test_get_device_p2p_attribute() { - let device_0 = Device::new(0); - let device_1 = Device::new(1); - - let attributes = vec![ - DeviceP2PAttribute::PerformanceRank, - DeviceP2PAttribute::AccessSupported, - DeviceP2PAttribute::NativeAtomicSupported, - DeviceP2PAttribute::HipArrayAccessSupported, - ]; - - for attr in attributes { - let result = get_device_p2p_attribute(attr, device_0, device_1); - assert!(result.is_ok()); - let value = result.unwrap(); - println!( - "{:?} attribute value between device {} and {}: {}", - attr, - device_0.id(), - device_1.id(), - value - ); - } - } - - #[test] - fn test_get_device_p2p_attribute_same_device() { - let device = Device::new(0); - - let attributes = vec![ - DeviceP2PAttribute::PerformanceRank, - DeviceP2PAttribute::AccessSupported, - DeviceP2PAttribute::NativeAtomicSupported, - DeviceP2PAttribute::HipArrayAccessSupported, - ]; - - for attr in attributes { - let result = get_device_p2p_attribute(attr, device, device); - assert!( - result.is_err(), - "expect getting P2P attribute from same device will fail, failed for attribute {:?}", - attr - ); - } - } - - #[test] - fn test_get_device_p2p_attribute_invalid_device() { - let device = Device::new(0); - let invalid_device = Device::new(99); - - let result = - get_device_p2p_attribute(DeviceP2PAttribute::AccessSupported, device, invalid_device); + let result = invalid_device.get_device_pci_bus_id(); assert!(result.is_err()); assert_eq!(result.unwrap_err().kind, HipErrorKind::InvalidDevice); } @@ -355,7 +337,7 @@ mod tests { #[test] fn test_get_device_uuid_bytes() { let device = Device::new(0); - let result = get_device_uuid_bytes(device); + let result = device.get_device_uuid_bytes(); assert!(result.is_ok()); let uuid_bytes = result.unwrap(); assert_eq!(uuid_bytes.len(), 16); @@ -365,7 +347,7 @@ mod tests { #[test] fn test_get_device_uuid() { let device = Device::new(0); - let result = get_device_uuid(device); + let result = device.get_device_uuid(); assert!(result.is_ok()); let uuid = result.unwrap(); println!("Device UUID: {}", uuid); @@ -374,7 +356,7 @@ mod tests { #[test] fn test_get_device_name() { let device = Device::new(0); - let result = get_device_name(device); + let result = device.get_device_name(); assert!(result.is_ok()); let name = result.unwrap(); println!("Device name: {}", name); @@ -383,7 +365,7 @@ mod tests { #[test] fn test_device_total_mem() { let device = Device::new(0); - let result = device_total_mem(device); + let result = device.device_total_mem(); assert!(result.is_ok()); let size = result.unwrap(); assert!(size > 0); @@ -393,16 +375,16 @@ mod tests { #[test] fn test_get_device_compute_capability() { let device = Device::new(0); - let result = device_compute_capability(device); + let result = device.device_compute_capability(); assert!(result.is_ok()); let version = result.unwrap(); assert!(version.major > 0); println!("Compute Capability: {}.{}", version.major, version.minor); } + // These tests remain unchanged as they test free functions #[test] fn test_get_device_count() { - // Test success case let result = get_device_count(); assert!(result.is_ok()); let count = result.unwrap(); @@ -412,26 +394,23 @@ mod tests { #[test] fn test_get_device() { - // Test success case let result = get_device(); assert!(result.is_ok()); let device = result.unwrap(); - println!("Device {} is currently active", device.id); - assert_eq!(device.id, 0); + println!("Device {} is currently active", device.id()); + assert_eq!(device.id(), 0); } #[test] fn test_set_device() { - // Test success case with valid device - let device = Device::new(1); + let device = Device::new(0); let result = set_device(device); assert!(result.is_ok()); - assert_eq!(result.unwrap().id(), 1) + assert_eq!(result.unwrap().id(), 0) } #[test] fn test_set_invalid_device() { - // Test error case with invalid device let invalid_device = Device::new(99); let result = set_device(invalid_device); assert!(result.is_err()); diff --git a/src/core/device_type.rs b/src/core/device_types.rs similarity index 91% rename from src/core/device_type.rs rename to src/core/device_types.rs index 2502940..8fe7987 100644 --- a/src/core/device_type.rs +++ b/src/core/device_types.rs @@ -1,67 +1,7 @@ -use super::result::{HipError, HipErrorKind, Result}; -use crate::sys; +use super::sys; +use super::{HipError, HipErrorKind, HipResult, Result}; use std::ffi::CStr; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Device { - pub(crate) id: i32, -} - -impl Device { - /// Create a new Device handle - pub fn new(id: i32) -> Self { - Device { id } - } - - /// Get the raw device ID - pub fn id(&self) -> i32 { - self.id - } -} -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum DeviceP2PAttribute { - PerformanceRank, - AccessSupported, - NativeAtomicSupported, - HipArrayAccessSupported, -} - -impl From for u32 { - fn from(attr: DeviceP2PAttribute) -> Self { - match attr { - DeviceP2PAttribute::PerformanceRank => { - sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank - } - DeviceP2PAttribute::AccessSupported => { - sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported - } - DeviceP2PAttribute::NativeAtomicSupported => { - sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported - } - DeviceP2PAttribute::HipArrayAccessSupported => { - sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported - } - } - } -} - -impl TryFrom for DeviceP2PAttribute { - type Error = HipError; - - fn try_from(value: sys::hipDeviceP2PAttr) -> Result { - match value { - sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank => Ok(Self::PerformanceRank), - sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported => Ok(Self::AccessSupported), - sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported => { - Ok(Self::NativeAtomicSupported) - } - sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported => { - Ok(Self::HipArrayAccessSupported) - } - _ => Err(HipError::from_kind(HipErrorKind::InvalidValue)), - } - } -} +use std::i32; pub unsafe trait UnsafeToString { unsafe fn to_string(&self) -> String; @@ -116,6 +56,51 @@ impl PCIBusId { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DeviceP2PAttribute { + PerformanceRank, + AccessSupported, + NativeAtomicSupported, + HipArrayAccessSupported, +} + +impl From for u32 { + fn from(attr: DeviceP2PAttribute) -> Self { + match attr { + DeviceP2PAttribute::PerformanceRank => { + sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank + } + DeviceP2PAttribute::AccessSupported => { + sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported + } + DeviceP2PAttribute::NativeAtomicSupported => { + sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported + } + DeviceP2PAttribute::HipArrayAccessSupported => { + sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported + } + } + } +} + +impl TryFrom for DeviceP2PAttribute { + type Error = HipError; + + fn try_from(value: sys::hipDeviceP2PAttr) -> Result { + match value { + sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank => Ok(Self::PerformanceRank), + sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported => Ok(Self::AccessSupported), + sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported => { + Ok(Self::NativeAtomicSupported) + } + sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported => { + Ok(Self::HipArrayAccessSupported) + } + _ => Err(HipError::from_kind(HipErrorKind::InvalidValue)), + } + } +} + /// Unsafe implementation for converting PCIBusId to a String. unsafe impl UnsafeToString for PCIBusId { /// Converts the internal buffer to a String. diff --git a/src/core/init.rs b/src/core/init.rs index 3fcc400..78c6496 100644 --- a/src/core/init.rs +++ b/src/core/init.rs @@ -1,5 +1,5 @@ use super::sys; -use crate::{Device, HipErrorKind, HipResult, Result}; +use crate::{HipResult, Result}; use semver::Version; use std::i32; diff --git a/src/core/memory_type.rs b/src/core/memory.rs similarity index 99% rename from src/core/memory_type.rs rename to src/core/memory.rs index 8f9f6f9..09e09ed 100644 --- a/src/core/memory_type.rs +++ b/src/core/memory.rs @@ -245,8 +245,7 @@ impl TryFrom for MemoryCopyKind { #[cfg(test)] mod tests { - use crate::Device; - + // use crate::Device; use super::*; use std::thread::sleep; use std::time::Duration; diff --git a/src/core/mod.rs b/src/core/mod.rs index a784437..41ed661 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,17 +1,17 @@ mod device; -mod device_type; +mod device_types; mod flags; mod init; -mod memory_type; +mod memory; mod result; mod stream; pub mod sys; // Re-export core functionality pub use device::*; -pub use device_type::*; +pub use device_types::*; pub use flags::*; pub use init::*; -pub use memory_type::*; +pub use memory::*; pub use result::*; pub use stream::*; diff --git a/src/core/stream.rs b/src/core/stream.rs index 10b2c20..9c60261 100644 --- a/src/core/stream.rs +++ b/src/core/stream.rs @@ -1,4 +1,4 @@ -use crate::{sys, HipErrorKind, HipResult, Result}; +use crate::{sys, HipResult, Result}; /// A handle to a HIP stream that executes commands in order. #[derive(Debug)]