Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

closes #4. hipdevicegetp2pattribute #16

Merged
merged 17 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,4 @@
//! HIP Runtime API bindings
mod result;
mod safe;
pub mod sys;

pub use result::{HipError, Result};
pub use safe::*;

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_hip_init() {
initialize().expect("Failed to initialize HIP");
let count = get_device_count().expect("Failed to get device count");
println!("Found {} HIP devices", count);
}
}
mod types;
107 changes: 99 additions & 8 deletions src/runtime/safe.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::result::{HipError, HipErrorKind, HipResult, Result};
use super::sys;
use super::types::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, Result};
use crate::types::Device;
use semver::Version;
use std::ffi::CStr;
Expand Down Expand Up @@ -45,7 +45,7 @@ pub fn get_device_count() -> Result<i32> {
///
/// # Returns
/// Returns a `Result` containing either:
/// * `Ok(Device)` - The currently active device if one is set
/// * `Ok(Device)` - The currently active device [`crate::Device`] if one is set
/// * `Err(HipError)` - If getting the device failed
///
/// # Errors
Expand All @@ -67,7 +67,7 @@ pub fn get_device() -> Result<Device> {
/// in the current host thread. Other host threads are not affected.
///
/// # Arguments
/// * `device` - The device to make active
/// * `device` - The device [`crate::Device`] to make active
///
/// # Returns
/// * `Ok(())` if the device was successfully made active
Expand All @@ -92,7 +92,7 @@ pub fn set_device(device: Device) -> Result<Device> {
/// supported by the device's architecture.
///
/// # Arguments
/// * `device` - A `Device` instance representing the HIP device to query
/// * `device` - The device [`crate::Device`] to query instance representing the HIP device to query
///
/// # Returns
/// * `Result<Version>` - On success, returns a `Version` struct containing the major and minor version
Expand All @@ -110,7 +110,7 @@ pub fn device_compute_capability(device: Device) -> Result<Version> {
/// Returns the total amount of memory on a HIP device.
///
/// # Arguments
/// * `device` - The device to query
/// * `device` - The device [`crate::Device`] to query
///
/// # Returns
/// * `Result<usize>` - The total memory in bytes if successful
Expand Down Expand Up @@ -167,7 +167,7 @@ pub fn runtime_get_version() -> Result<Version> {
/// Gets the name of a HIP device.
///
/// # Arguments
/// * `device` - The device ID to query
/// * `device` - The device [`crate::Device`] to query
///
/// # Returns
/// * `Result<String>` - The device name if successful
Expand All @@ -192,7 +192,7 @@ pub fn get_device_name(device: Device) -> Result<String> {
/// Gets the UUID bytes for a HIP device.
///
/// # Arguments
/// * `device` - The device to query
/// * `device` - The device [`crate::Device`] to query
///
/// # Returns
/// * `Result<[i8; 16]>` - The UUID as a 16-byte array if successful
Expand All @@ -215,7 +215,7 @@ fn get_device_uuid_bytes(device: Device) -> Result<[i8; 16]> {
/// Retrieves the unique identifier (UUID) for a specified HIP device,
///
/// # Arguments
/// * `device` - The device to query
/// * `device` - The device [`crate::Device`] to query
///
/// # Returns
/// * `Result<Uuid>` - The device UUID if successful
Expand All @@ -232,10 +232,101 @@ pub fn get_device_uuid(device: Device) -> Result<Uuid> {
})
}

/// Retrieves a peer-to-peer attribute value between two HIP devices.
///
/// This function queries the specified peer-to-peer attribute between a source and destination device.
/// The attribute can be used to determine various P2P capabilities and performance characteristics
/// between the two devices.
///
/// # Arguments
/// * `src_device` - Source [`crate::Device`] for P2P attribute query
/// * `dst_device` - Target [`crate::Device`] for P2P attribute query
/// * `attr` - The [`DeviceP2PAttribute`](DeviceP2PAttribute) to query
///
/// # Returns
/// * `Result<i32>` - The attribute value if successful
///
/// # Errors
/// Returns `HipError` if:
/// * Either device ID is invalid
/// * The devices are the same
/// * The runtime is not initialized
/// * Getting the attribute fails
pub fn get_device_p2p_attribute(
attr: DeviceP2PAttribute,
src_device: Device,
dst_device: Device,
) -> Result<i32> {
let mut value = -1;
unsafe {
let code =
sys::hipDeviceGetP2PAttribute(&mut value, attr.into(), src_device.id, dst_device.id);
(value, code).to_result()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_get_device_p2p_attribute() {
let device_0 = Device::new(0);
let device_1 = Device::new(1);

let attributes = vec![
DeviceP2PAttribute::PerformanceRank,
DeviceP2PAttribute::AccessSupported,
DeviceP2PAttribute::NativeAtomicSupported,
DeviceP2PAttribute::HipArrayAccessSupported,
];

for attr in attributes {
let result = get_device_p2p_attribute(attr, device_0, device_1);
assert!(result.is_ok());
let value = result.unwrap();
println!(
"{:?} attribute value between device {} and {}: {}",
attr,
device_0.id(),
device_1.id(),
value
);
}
}

#[test]
fn test_get_device_p2p_attribute_same_device() {
let device = Device::new(0);

let attributes = vec![
DeviceP2PAttribute::PerformanceRank,
DeviceP2PAttribute::AccessSupported,
DeviceP2PAttribute::NativeAtomicSupported,
DeviceP2PAttribute::HipArrayAccessSupported,
];

for attr in attributes {
let result = get_device_p2p_attribute(attr, device, device);
assert!(
result.is_err(),
"expect getting P2P attribute from same device will fail, failed for attribute {:?}",
attr
);
}
}

#[test]
fn test_get_device_p2p_attribute_invalid_device() {
let device = Device::new(0);
let invalid_device = Device::new(99);

let result =
get_device_p2p_attribute(DeviceP2PAttribute::AccessSupported, device, invalid_device);
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind, HipErrorKind::InvalidDevice);
}

#[test]
fn test_get_device_uuid_bytes() {
let device = Device::new(0);
Expand Down
46 changes: 46 additions & 0 deletions src/runtime/result.rs → src/runtime/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::sys;
use std::fmt;

/// Success code from HIP runtime
Expand Down Expand Up @@ -95,3 +96,48 @@ impl<T> HipResult for (T, u32) {
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceP2PAttribute {
PerformanceRank,
AccessSupported,
NativeAtomicSupported,
HipArrayAccessSupported,
}

impl From<DeviceP2PAttribute> for u32 {
fn from(attr: DeviceP2PAttribute) -> Self {
match attr {
DeviceP2PAttribute::PerformanceRank => {
sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank
}
DeviceP2PAttribute::AccessSupported => {
sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported
}
DeviceP2PAttribute::NativeAtomicSupported => {
sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported
}
DeviceP2PAttribute::HipArrayAccessSupported => {
sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported
}
}
}
}

impl TryFrom<u32> for DeviceP2PAttribute {
type Error = HipError;

fn try_from(value: sys::hipDeviceP2PAttr) -> Result<Self> {
match value {
sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank => Ok(Self::PerformanceRank),
sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported => Ok(Self::AccessSupported),
sys::hipDeviceP2PAttr_hipDevP2PAttrNativeAtomicSupported => {
Ok(Self::NativeAtomicSupported)
}
sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported => {
Ok(Self::HipArrayAccessSupported)
}
_ => Err(HipError::from_kind(HipErrorKind::InvalidValue)),
}
}
}