Skip to content

Commit

Permalink
157 add blas call (#158)
Browse files Browse the repository at this point in the history
* add blas_call

* ignore unused import warnings

* make sys public

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
smedegaard and github-actions[bot] authored Dec 31, 2024
1 parent 49ce198 commit 94089e5
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/core/device.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(unused_imports)]
use super::result::{HipError, HipResult, HipStatus};
use super::{DeviceP2PAttribute, MemPool, PCIBusId};
use crate::result::ResultExt;
Expand Down
1 change: 1 addition & 0 deletions src/core/device_types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(unused_imports)]
use super::result::{HipError, HipResult, HipStatus};
use crate::sys;
use std::ffi::CStr;
Expand Down
9 changes: 6 additions & 3 deletions src/core/hip_call.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::result::{HipResult, HipStatus};
use crate::result::ResultExt;
use crate::sys;
#[allow(unused_imports)]
use {
super::result::{HipResult, HipStatus},
crate::result::ResultExt,
crate::sys,
};

#[macro_export]
macro_rules! hip_call {
Expand Down
1 change: 1 addition & 0 deletions src/core/init.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(unused_imports)]
use super::result::{HipResult, HipStatus};
pub use crate::result::ResultExt;
use crate::sys;
Expand Down
1 change: 1 addition & 0 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod stream;
pub use device::*;
pub use device_types::*;
pub use flags::*;
#[allow(unused_imports)]
pub use hip_call::*;
pub use init::*;
pub use memory::*;
Expand Down
1 change: 1 addition & 0 deletions src/core/stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(unused_imports)]
use super::result::{HipResult, HipStatus};
use crate::result::ResultExt;
use crate::sys;
Expand Down
170 changes: 170 additions & 0 deletions src/hipblas/blas_call.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#[allow(unused_imports)]
use {
super::result::{BlasError, BlasResult},
crate::result::ResultExt,
crate::sys,
};

/// Executes a BLAS Basic Linear Algebra Subprograms) function call and converts the result into a BlasResult.
///
/// This macro wraps unsafe BLAS function calls and handles error checking by converting
/// the returned status code into a proper Result value.
///
/// # Arguments
///
/// * `$call` - The BLAS function call expression to execute
///
/// # Returns
///
/// * `BlasResult<()>` - Ok(()) if successful, Err(BlasError) if there was an error
///
/// # Examples
///
/// ```ignore
/// // this example will not compile, but give the basic idea of how
/// // to use `blas_call!`
///
/// use hip_rs::sys;
/// use hip_rs::blas_call;
/// let mut result = 0.0f32;
/// let blas_result = blas_call!(
/// sys::hipblasSasum(handle.handle(), n, x.as_pointer(), 1, &mut result)
/// );
/// ```
///
#[macro_export]
macro_rules! blas_call {
($call:expr) => {{
let code: u32 = unsafe { $call };
let result: BlasResult<()> = ((), code).to_result();
result
}};
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{sys, BlasHandle, MemoryPointer};

fn setup_test_vector() -> (BlasHandle, MemoryPointer<f32>) {
let handle = BlasHandle::new().unwrap();

// Create a test vector with known values
let n = 5;
let vec = MemoryPointer::<f32>::alloc(n).unwrap();

// Initialize vector with test values
let host_data: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 5.0];
unsafe {
sys::hipMemcpy(
vec.as_pointer() as *mut std::ffi::c_void,
host_data.as_ptr() as *const std::ffi::c_void,
(n * std::mem::size_of::<f32>()) as usize,
sys::hipMemcpyKind_hipMemcpyHostToDevice,
);
}

(handle, vec)
}

#[test]
fn test_isamin() {
let (handle, vec) = setup_test_vector();
let mut result: i32 = 0;

let blas_result = blas_call!(sys::hipblasIsamin(
handle.handle(),
5, // n elements
vec.as_pointer(),
1, // stride
&mut result,
));
assert!(blas_result.is_ok());
assert_eq!(result, 1); // 1.0 has maximum absolute value (1-based indexing)
}

#[test]
fn test_isamax() {
let (handle, vec) = setup_test_vector();
let mut result: i32 = 0;

let blas_result = blas_call!(sys::hipblasIsamax(
handle.handle(),
5, // n elements
vec.as_pointer(),
1, // stride
&mut result,
));

assert!(blas_result.is_ok());
assert_eq!(result, 5); // 5.0 has maximum absolute value (1-based indexing)
}

#[test]
fn test_sasum() {
let (handle, vec) = setup_test_vector();
let mut result: f32 = 0.0;

let blas_result = blas_call!(sys::hipblasSasum(
handle.handle(),
5, // n elements
vec.as_pointer(),
1, // stride
&mut result,
));

assert!(blas_result.is_ok());
// Expected sum of absolute values: |1.0| + |-2.0| + |3.0| + |-4.0| + |5.0| = 15.0
assert_eq!(result, 15.0);
}

#[test]
fn test_invalid_handle() {
let (_, vec) = setup_test_vector();
let mut result: f32 = 0.0;

let blas_result = blas_call!(sys::hipblasSasum(
std::ptr::null_mut(), // Invalid handle
5,
vec.as_pointer(),
1,
&mut result,
));

assert!(blas_result.is_err());
}

#[test]
fn test_invalid_pointer() {
let handle = BlasHandle::new().unwrap();
let mut result: f32 = 0.0;

let blas_result = blas_call!(sys::hipblasSasum(
handle.handle(),
5,
std::ptr::null(), // Invalid pointer
1,
&mut result,
));

assert!(blas_result.is_err());
}

#[test]
fn test_zero_length() {
let (handle, vec) = setup_test_vector();
let mut result: f32 = 0.0;

let blas_result = blas_call!(sys::hipblasSasum(
handle.handle(),
0, // Zero length
vec.as_pointer(),
1,
&mut result,
));

assert!(blas_result.is_ok());
assert_eq!(result, 0.0); // Sum should be 0
}
}
4 changes: 2 additions & 2 deletions src/hipblas/gemm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{BlasHandle, Operation, Result};
use super::{BlasHandle, BlasResult, Operation};
use crate::result::ResultExt;
use crate::Complex32;
use crate::{sys, MemoryPointer};
Expand Down Expand Up @@ -184,7 +184,7 @@ pub fn gemm<T: GemmDatatype>(
beta: &T,
c: &mut MemoryPointer<T>,
ldc: i32,
) -> Result<()> {
) -> BlasResult<()> {
unsafe {
let code = T::hipblas_gemm(
handle.handle(),
Expand Down
4 changes: 2 additions & 2 deletions src/hipblas/handle.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Result;
use super::BlasResult;
use crate::result::ResultExt;
use crate::sys;
use std::fmt;
Expand Down Expand Up @@ -42,7 +42,7 @@ impl BlasHandle {
///
/// let handle = BlasHandle::new().unwrap();
/// ```
pub fn new() -> Result<Self> {
pub fn new() -> BlasResult<Self> {
let mut handle = std::ptr::null_mut();
unsafe {
let status = sys::hipblasCreate(&mut handle);
Expand Down
3 changes: 3 additions & 0 deletions src/hipblas/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
mod blas_call;
mod gemm;
mod handle;
mod result;
mod types;

#[allow(unused_imports)]
pub use blas_call::*;
pub use gemm::*;
pub use handle::*;
pub use result::*;
Expand Down
8 changes: 4 additions & 4 deletions src/hipblas/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ impl StatusCode for BlasError {
}
}

pub type Result<T> = std::result::Result<T, BlasError>;
pub type BlasResult<T> = std::result::Result<T, BlasError>;

impl<T> ResultExt<T, BlasError> for (T, u32) {
type Value = T;
fn to_result(self) -> Result<T> {
fn to_result(self) -> BlasResult<T> {
let (value, status) = self;
(value, BlasError::new(status)).to_result()
}
Expand Down Expand Up @@ -151,8 +151,8 @@ mod tests {

#[test]
fn test_result_ext() {
let success: Result<i32> = (42, 0).to_result();
let error: Result<i32> = (42, 1).to_result();
let success: BlasResult<i32> = (42, 0).to_result();
let error: BlasResult<i32> = (42, 1).to_result();

assert!(success.is_ok());
assert!(error.is_err());
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
mod core;
mod hipblas;
mod result;
mod sys;
pub mod sys;

pub use core::*;
pub use hipblas::*;
pub use result::*;
pub use sys::*;

0 comments on commit 94089e5

Please sign in to comment.