diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 4f14487..9613cec 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -1,19 +1,4 @@ //! HIP Runtime API bindings -mod result; mod safe; pub mod sys; - -pub use result::{HipError, Result}; -pub use safe::*; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_hip_init() { - initialize().expect("Failed to initialize HIP"); - let count = get_device_count().expect("Failed to get device count"); - println!("Found {} HIP devices", count); - } -} +mod types; diff --git a/src/runtime/safe.rs b/src/runtime/safe.rs index 61da732..1e5fbee 100644 --- a/src/runtime/safe.rs +++ b/src/runtime/safe.rs @@ -1,5 +1,5 @@ -use super::result::{HipError, HipErrorKind, HipResult, Result}; use super::sys; +use super::types::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, Result}; use crate::types::Device; use semver::Version; use std::ffi::CStr; @@ -45,7 +45,7 @@ pub fn get_device_count() -> Result { /// /// # Returns /// Returns a `Result` containing either: -/// * `Ok(Device)` - The currently active device if one is set +/// * `Ok(Device)` - The currently active device [`crate::Device`] if one is set /// * `Err(HipError)` - If getting the device failed /// /// # Errors @@ -67,7 +67,7 @@ pub fn get_device() -> Result { /// in the current host thread. Other host threads are not affected. /// /// # Arguments -/// * `device` - The device to make active +/// * `device` - The device [`crate::Device`] to make active /// /// # Returns /// * `Ok(())` if the device was successfully made active @@ -92,7 +92,7 @@ pub fn set_device(device: Device) -> Result { /// supported by the device's architecture. /// /// # Arguments -/// * `device` - A `Device` instance representing the HIP device to query +/// * `device` - The device [`crate::Device`] to query instance representing the HIP device to query /// /// # Returns /// * `Result` - On success, returns a `Version` struct containing the major and minor version @@ -110,7 +110,7 @@ pub fn device_compute_capability(device: Device) -> Result { /// Returns the total amount of memory on a HIP device. /// /// # Arguments -/// * `device` - The device to query +/// * `device` - The device [`crate::Device`] to query /// /// # Returns /// * `Result` - The total memory in bytes if successful @@ -167,7 +167,7 @@ pub fn runtime_get_version() -> Result { /// Gets the name of a HIP device. /// /// # Arguments -/// * `device` - The device ID to query +/// * `device` - The device [`crate::Device`] to query /// /// # Returns /// * `Result` - The device name if successful @@ -192,7 +192,7 @@ pub fn get_device_name(device: Device) -> Result { /// Gets the UUID bytes for a HIP device. /// /// # Arguments -/// * `device` - The device to query +/// * `device` - The device [`crate::Device`] to query /// /// # Returns /// * `Result<[i8; 16]>` - The UUID as a 16-byte array if successful @@ -215,7 +215,7 @@ fn get_device_uuid_bytes(device: Device) -> Result<[i8; 16]> { /// Retrieves the unique identifier (UUID) for a specified HIP device, /// /// # Arguments -/// * `device` - The device to query +/// * `device` - The device [`crate::Device`] to query /// /// # Returns /// * `Result` - The device UUID if successful @@ -232,10 +232,101 @@ pub fn get_device_uuid(device: Device) -> Result { }) } +/// Retrieves a peer-to-peer attribute value between two HIP devices. +/// +/// This function queries the specified peer-to-peer attribute between a source and destination device. +/// The attribute can be used to determine various P2P capabilities and performance characteristics +/// between the two devices. +/// +/// # Arguments +/// * `src_device` - Source [`crate::Device`] for P2P attribute query +/// * `dst_device` - Target [`crate::Device`] for P2P attribute query +/// * `attr` - The [`DeviceP2PAttribute`](DeviceP2PAttribute) to query +/// +/// # Returns +/// * `Result` - The attribute value if successful +/// +/// # Errors +/// Returns `HipError` if: +/// * Either device ID is invalid +/// * The devices are the same +/// * The runtime is not initialized +/// * Getting the attribute fails +pub fn get_device_p2p_attribute( + attr: DeviceP2PAttribute, + src_device: Device, + dst_device: Device, +) -> Result { + let mut value = -1; + unsafe { + let code = + sys::hipDeviceGetP2PAttribute(&mut value, attr.into(), src_device.id, dst_device.id); + (value, code).to_result() + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn test_get_device_p2p_attribute() { + let device_0 = Device::new(0); + let device_1 = Device::new(1); + + let attributes = vec![ + DeviceP2PAttribute::PerformanceRank, + DeviceP2PAttribute::AccessSupported, + DeviceP2PAttribute::NativeAtomicSupported, + DeviceP2PAttribute::HipArrayAccessSupported, + ]; + + for attr in attributes { + let result = get_device_p2p_attribute(attr, device_0, device_1); + assert!(result.is_ok()); + let value = result.unwrap(); + println!( + "{:?} attribute value between device {} and {}: {}", + attr, + device_0.id(), + device_1.id(), + value + ); + } + } + + #[test] + fn test_get_device_p2p_attribute_same_device() { + let device = Device::new(0); + + let attributes = vec![ + DeviceP2PAttribute::PerformanceRank, + DeviceP2PAttribute::AccessSupported, + DeviceP2PAttribute::NativeAtomicSupported, + DeviceP2PAttribute::HipArrayAccessSupported, + ]; + + for attr in attributes { + let result = get_device_p2p_attribute(attr, device, device); + assert!( + result.is_err(), + "expect getting P2P attribute from same device will fail, failed for attribute {:?}", + attr + ); + } + } + + #[test] + fn test_get_device_p2p_attribute_invalid_device() { + let device = Device::new(0); + let invalid_device = Device::new(99); + + let result = + get_device_p2p_attribute(DeviceP2PAttribute::AccessSupported, device, invalid_device); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind, HipErrorKind::InvalidDevice); + } + #[test] fn test_get_device_uuid_bytes() { let device = Device::new(0); diff --git a/src/runtime/result.rs b/src/runtime/types.rs similarity index 59% rename from src/runtime/result.rs rename to src/runtime/types.rs index 84adbd8..a09706d 100644 --- a/src/runtime/result.rs +++ b/src/runtime/types.rs @@ -1,3 +1,4 @@ +use super::sys; use std::fmt; /// Success code from HIP runtime @@ -95,3 +96,48 @@ impl HipResult for (T, u32) { } } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DeviceP2PAttribute { + PerformanceRank, + AccessSupported, + NativeAtomicSupported, + HipArrayAccessSupported, +} + +impl From for u32 { + fn from(attr: DeviceP2PAttribute) -> Self { + match attr { + DeviceP2PAttribute::PerformanceRank => { + sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank + } + DeviceP2PAttribute::AccessSupported => { + sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported + } + DeviceP2PAttribute::NativeAtomicSupported => { + sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported + } + DeviceP2PAttribute::HipArrayAccessSupported => { + sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported + } + } + } +} + +impl TryFrom for DeviceP2PAttribute { + type Error = HipError; + + fn try_from(value: sys::hipDeviceP2PAttr) -> Result { + match value { + sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank => Ok(Self::PerformanceRank), + sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported => Ok(Self::AccessSupported), + sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported => { + Ok(Self::NativeAtomicSupported) + } + sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported => { + Ok(Self::HipArrayAccessSupported) + } + _ => Err(HipError::from_kind(HipErrorKind::InvalidValue)), + } + } +}