From 94089e5a6c5a84fe248ed081ac435847f58f70e7 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 31 Dec 2024 10:15:12 +0100 Subject: [PATCH] 157 add blas call (#158) * add blas_call * ignore unused import warnings * make sys public --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/core/device.rs | 1 + src/core/device_types.rs | 1 + src/core/hip_call.rs | 9 ++- src/core/init.rs | 1 + src/core/mod.rs | 1 + src/core/stream.rs | 1 + src/hipblas/blas_call.rs | 170 +++++++++++++++++++++++++++++++++++++++ src/hipblas/gemm.rs | 4 +- src/hipblas/handle.rs | 4 +- src/hipblas/mod.rs | 3 + src/hipblas/result.rs | 8 +- src/lib.rs | 3 +- 12 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 src/hipblas/blas_call.rs diff --git a/src/core/device.rs b/src/core/device.rs index 9cabea3..c04db61 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipError, HipResult, HipStatus}; use super::{DeviceP2PAttribute, MemPool, PCIBusId}; use crate::result::ResultExt; diff --git a/src/core/device_types.rs b/src/core/device_types.rs index d7bae87..8fb1c62 100644 --- a/src/core/device_types.rs +++ b/src/core/device_types.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipError, HipResult, HipStatus}; use crate::sys; use std::ffi::CStr; diff --git a/src/core/hip_call.rs b/src/core/hip_call.rs index 41ad169..3b19427 100644 --- a/src/core/hip_call.rs +++ b/src/core/hip_call.rs @@ -1,6 +1,9 @@ -use super::result::{HipResult, HipStatus}; -use crate::result::ResultExt; -use crate::sys; +#[allow(unused_imports)] +use { + super::result::{HipResult, HipStatus}, + crate::result::ResultExt, + crate::sys, +}; #[macro_export] macro_rules! hip_call { diff --git a/src/core/init.rs b/src/core/init.rs index 66b2a7e..fac84a1 100644 --- a/src/core/init.rs +++ b/src/core/init.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipResult, HipStatus}; pub use crate::result::ResultExt; use crate::sys; diff --git a/src/core/mod.rs b/src/core/mod.rs index 197ae87..1421ae0 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -12,6 +12,7 @@ mod stream; pub use device::*; pub use device_types::*; pub use flags::*; +#[allow(unused_imports)] pub use hip_call::*; pub use init::*; pub use memory::*; diff --git a/src/core/stream.rs b/src/core/stream.rs index 171b546..b5c2d73 100644 --- a/src/core/stream.rs +++ b/src/core/stream.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipResult, HipStatus}; use crate::result::ResultExt; use crate::sys; diff --git a/src/hipblas/blas_call.rs b/src/hipblas/blas_call.rs new file mode 100644 index 0000000..b824b33 --- /dev/null +++ b/src/hipblas/blas_call.rs @@ -0,0 +1,170 @@ +#[allow(unused_imports)] +use { + super::result::{BlasError, BlasResult}, + crate::result::ResultExt, + crate::sys, +}; + +/// Executes a BLAS Basic Linear Algebra Subprograms) function call and converts the result into a BlasResult. +/// +/// This macro wraps unsafe BLAS function calls and handles error checking by converting +/// the returned status code into a proper Result value. +/// +/// # Arguments +/// +/// * `$call` - The BLAS function call expression to execute +/// +/// # Returns +/// +/// * `BlasResult<()>` - Ok(()) if successful, Err(BlasError) if there was an error +/// +/// # Examples +/// +/// ```ignore +/// // this example will not compile, but give the basic idea of how +/// // to use `blas_call!` +/// +/// use hip_rs::sys; +/// use hip_rs::blas_call; +/// let mut result = 0.0f32; +/// let blas_result = blas_call!( +/// sys::hipblasSasum(handle.handle(), n, x.as_pointer(), 1, &mut result) +/// ); +/// ``` +/// + +#[macro_export] +macro_rules! blas_call { + ($call:expr) => {{ + let code: u32 = unsafe { $call }; + let result: BlasResult<()> = ((), code).to_result(); + result + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{sys, BlasHandle, MemoryPointer}; + + fn setup_test_vector() -> (BlasHandle, MemoryPointer) { + let handle = BlasHandle::new().unwrap(); + + // Create a test vector with known values + let n = 5; + let vec = MemoryPointer::::alloc(n).unwrap(); + + // Initialize vector with test values + let host_data: Vec = vec![1.0, -2.0, 3.0, -4.0, 5.0]; + unsafe { + sys::hipMemcpy( + vec.as_pointer() as *mut std::ffi::c_void, + host_data.as_ptr() as *const std::ffi::c_void, + (n * std::mem::size_of::()) as usize, + sys::hipMemcpyKind_hipMemcpyHostToDevice, + ); + } + + (handle, vec) + } + + #[test] + fn test_isamin() { + let (handle, vec) = setup_test_vector(); + let mut result: i32 = 0; + + let blas_result = blas_call!(sys::hipblasIsamin( + handle.handle(), + 5, // n elements + vec.as_pointer(), + 1, // stride + &mut result, + )); + assert!(blas_result.is_ok()); + assert_eq!(result, 1); // 1.0 has maximum absolute value (1-based indexing) + } + + #[test] + fn test_isamax() { + let (handle, vec) = setup_test_vector(); + let mut result: i32 = 0; + + let blas_result = blas_call!(sys::hipblasIsamax( + handle.handle(), + 5, // n elements + vec.as_pointer(), + 1, // stride + &mut result, + )); + + assert!(blas_result.is_ok()); + assert_eq!(result, 5); // 5.0 has maximum absolute value (1-based indexing) + } + + #[test] + fn test_sasum() { + let (handle, vec) = setup_test_vector(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + handle.handle(), + 5, // n elements + vec.as_pointer(), + 1, // stride + &mut result, + )); + + assert!(blas_result.is_ok()); + // Expected sum of absolute values: |1.0| + |-2.0| + |3.0| + |-4.0| + |5.0| = 15.0 + assert_eq!(result, 15.0); + } + + #[test] + fn test_invalid_handle() { + let (_, vec) = setup_test_vector(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + std::ptr::null_mut(), // Invalid handle + 5, + vec.as_pointer(), + 1, + &mut result, + )); + + assert!(blas_result.is_err()); + } + + #[test] + fn test_invalid_pointer() { + let handle = BlasHandle::new().unwrap(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + handle.handle(), + 5, + std::ptr::null(), // Invalid pointer + 1, + &mut result, + )); + + assert!(blas_result.is_err()); + } + + #[test] + fn test_zero_length() { + let (handle, vec) = setup_test_vector(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + handle.handle(), + 0, // Zero length + vec.as_pointer(), + 1, + &mut result, + )); + + assert!(blas_result.is_ok()); + assert_eq!(result, 0.0); // Sum should be 0 + } +} diff --git a/src/hipblas/gemm.rs b/src/hipblas/gemm.rs index 7c251a8..f7dc25b 100644 --- a/src/hipblas/gemm.rs +++ b/src/hipblas/gemm.rs @@ -1,4 +1,4 @@ -use super::{BlasHandle, Operation, Result}; +use super::{BlasHandle, BlasResult, Operation}; use crate::result::ResultExt; use crate::Complex32; use crate::{sys, MemoryPointer}; @@ -184,7 +184,7 @@ pub fn gemm( beta: &T, c: &mut MemoryPointer, ldc: i32, -) -> Result<()> { +) -> BlasResult<()> { unsafe { let code = T::hipblas_gemm( handle.handle(), diff --git a/src/hipblas/handle.rs b/src/hipblas/handle.rs index f67e0a2..a85d070 100644 --- a/src/hipblas/handle.rs +++ b/src/hipblas/handle.rs @@ -1,4 +1,4 @@ -use super::Result; +use super::BlasResult; use crate::result::ResultExt; use crate::sys; use std::fmt; @@ -42,7 +42,7 @@ impl BlasHandle { /// /// let handle = BlasHandle::new().unwrap(); /// ``` - pub fn new() -> Result { + pub fn new() -> BlasResult { let mut handle = std::ptr::null_mut(); unsafe { let status = sys::hipblasCreate(&mut handle); diff --git a/src/hipblas/mod.rs b/src/hipblas/mod.rs index 31dc80e..1e34302 100644 --- a/src/hipblas/mod.rs +++ b/src/hipblas/mod.rs @@ -1,8 +1,11 @@ +mod blas_call; mod gemm; mod handle; mod result; mod types; +#[allow(unused_imports)] +pub use blas_call::*; pub use gemm::*; pub use handle::*; pub use result::*; diff --git a/src/hipblas/result.rs b/src/hipblas/result.rs index 98e4d19..9104158 100644 --- a/src/hipblas/result.rs +++ b/src/hipblas/result.rs @@ -89,11 +89,11 @@ impl StatusCode for BlasError { } } -pub type Result = std::result::Result; +pub type BlasResult = std::result::Result; impl ResultExt for (T, u32) { type Value = T; - fn to_result(self) -> Result { + fn to_result(self) -> BlasResult { let (value, status) = self; (value, BlasError::new(status)).to_result() } @@ -151,8 +151,8 @@ mod tests { #[test] fn test_result_ext() { - let success: Result = (42, 0).to_result(); - let error: Result = (42, 1).to_result(); + let success: BlasResult = (42, 0).to_result(); + let error: BlasResult = (42, 1).to_result(); assert!(success.is_ok()); assert!(error.is_err()); diff --git a/src/lib.rs b/src/lib.rs index 7c105ac..9ebb254 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,9 @@ mod core; mod hipblas; mod result; -mod sys; +pub mod sys; pub use core::*; pub use hipblas::*; pub use result::*; +pub use sys::*;