From ce0d3d3fd8713feb71558b8ba9061a31edc70f07 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 19:44:22 +0100 Subject: [PATCH 01/11] chore: release v0.1.2 (#129) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- CHANGELOG.md | 14 ++++++++++++++ Cargo.toml | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf621bf..e03cad6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.0](https://github.com/smedegaard/hip-rs/compare/v0.1.1...v0.2.0) - 2024-12-09 + +### Other + +- skip binding generationon docs.rs ([#139](https://github.com/smedegaard/hip-rs/pull/139)) +- add MemPool ([#138](https://github.com/smedegaard/hip-rs/pull/138)) +- add stream::query_stream(). closes 134 ([#135](https://github.com/smedegaard/hip-rs/pull/135)) +- closes 26. add Device::get_default_mem_pool() ([#133](https://github.com/smedegaard/hip-rs/pull/133)) +- add synchronize() ([#132](https://github.com/smedegaard/hip-rs/pull/132)) +- only run CI on PR ([#131](https://github.com/smedegaard/hip-rs/pull/131)) +- Update release.yaml ([#130](https://github.com/smedegaard/hip-rs/pull/130)) +- closes 112. add core::stream ([#128](https://github.com/smedegaard/hip-rs/pull/128)) +- rename runtime -> core ([#127](https://github.com/smedegaard/hip-rs/pull/127)) + ## [0.1.1](https://github.com/smedegaard/hip-rs/compare/v0.1.0...v0.1.1) - 2024-12-02 ### Other diff --git a/Cargo.toml b/Cargo.toml index fa99ad5..937e352 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hip-rs" -version = "0.1.1" +version = "0.2.0" edition = "2021" authors = ["Anders Smedegaard Pedersen "] description = "A Rust wrapper for AMD's Heterogeneous-computing Interface for Portability (HIP), used for GPU interop." From 6225eebecf3c4f331c5f4ccd6cfcf9c63bee070e Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Sat, 21 Dec 2024 20:31:54 +0100 Subject: [PATCH 02/11] add hip_call() (#143) --- src/core/hip_call.rs | 35 +++++++++++++++++++++++++++++++++++ src/core/mod.rs | 2 ++ 2 files changed, 37 insertions(+) create mode 100644 src/core/hip_call.rs diff --git a/src/core/hip_call.rs b/src/core/hip_call.rs new file mode 100644 index 0000000..5bb0354 --- /dev/null +++ b/src/core/hip_call.rs @@ -0,0 +1,35 @@ +use super::{sys, HipResult, Result}; + +#[macro_export] +macro_rules! hip_call { + ($call:expr) => {{ + let code = unsafe { $call }; + ((), code).to_result() + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hip_call_simple() { + let result = hip_call!(sys::hipDeviceSynchronize()); + assert!(result.is_ok()); + } + + #[test] + fn test_hip_call_with_value() { + let mut count = 0; + let result = hip_call!(sys::hipGetDeviceCount(&mut count)); + assert!(result.is_ok()); + assert!(count > 0); + } + + #[test] + fn test_hip_call_error() { + // Call with invalid device ID should return error + let result = hip_call!(sys::hipSetDevice(99)); + assert!(result.is_err()); + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 41ed661..2e30cac 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,6 +1,7 @@ mod device; mod device_types; mod flags; +mod hip_call; mod init; mod memory; mod result; @@ -11,6 +12,7 @@ pub mod sys; pub use device::*; pub use device_types::*; pub use flags::*; +pub use hip_call::*; pub use init::*; pub use memory::*; pub use result::*; From 0314da4f5e7d2b71a0ecef1acec78d74b4a891ae Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 23 Dec 2024 12:48:39 +0100 Subject: [PATCH 03/11] add docs step to release job (#145) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .github/workflows/release.yaml | 99 +++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 9d09df0..cdc9b7d 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -1,50 +1,65 @@ name: Release permissions: - pull-requests: write - contents: write + pull-requests: write + contents: write on: - push: - branches: - - release + push: + branches: + - release jobs: - # Release unpublished packages. - release-plz-release: - name: Release-plz release - runs-on: self-hosted - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Run release-plz - uses: release-plz/action@v0.5 - with: - command: release - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_TOKEN }} + # Release unpublished packages. + release-plz-release: + name: Release-plz release + runs-on: self-hosted + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Run release-plz + uses: release-plz/action@v0.5 + with: + command: release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_TOKEN }} - # Create a PR with the new versions and changelog, preparing the next release. - release-plz-pr: - name: Release-plz PR - runs-on: self-hosted - concurrency: - group: release-plz-${{ github.ref }} - cancel-in-progress: false - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - - name: Run release-plz - uses: release-plz/action@v0.5 - with: - command: release-pr - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_TOKEN }} + # Create a PR with the new versions and changelog, preparing the next release. + release-plz-pr: + name: Release-plz PR + runs-on: self-hosted + concurrency: + group: release-plz-${{ github.ref }} + cancel-in-progress: false + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + - name: Run release-plz + uses: release-plz/action@v0.5 + with: + command: release-pr + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_TOKEN }} + + docs: + runs-on: self-hosted + steps: + - uses: actions/checkout@v3 + - name: Generate Docs + run: cargo doc --no-deps --document-private-items + - name: Add index.html + run: | + echo '' > target/doc/index.html + - name: Deploy + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./target/doc From b11988ec6aa85bed706ee018e61d3121d52e7a37 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 23 Dec 2024 13:09:22 +0100 Subject: [PATCH 04/11] add docs step to release job (#146) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> From f10ebb1b279134506bdaf3d17b281de2618aa65c Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 23 Dec 2024 13:24:23 +0100 Subject: [PATCH 05/11] 144 add GitHub pages (#148) * add docs step to release job * fix path --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .github/workflows/release.yaml | 50 ++++++++++++++++++++++++++-------- Cargo.toml | 2 +- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index cdc9b7d..a532f2b 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -39,27 +39,53 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + ref: release - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable - name: Run release-plz uses: release-plz/action@v0.5 with: + ref: release command: release-pr env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_TOKEN }} - docs: - runs-on: self-hosted + build_docs: + name: Build Docs + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Generate Docs - run: cargo doc --no-deps --document-private-items - - name: Add index.html - run: | - echo '' > target/doc/index.html - - name: Deploy - uses: peaceiris/actions-gh-pages@v3 + - name: Checkout repository + uses: actions/checkout@v4 with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./target/doc + ref: release + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + - name: Configure cache + uses: Swatinem/rust-cache@v2 + - name: Setup pages + id: pages + uses: actions/configure-pages@v5 + - name: Clean docs folder + run: cargo clean --doc + - name: Build docs + run: cargo doc --no-deps + - name: Add redirect + run: echo '' > target/doc/index.html + - name: Remove lock file + run: rm target/doc/.lock + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: target/doc + deploy_docs: + name: Deploy Docs + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/Cargo.toml b/Cargo.toml index 937e352..d442b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hip-rs" -version = "0.2.0" +version = "0.3.0" edition = "2021" authors = ["Anders Smedegaard Pedersen "] description = "A Rust wrapper for AMD's Heterogeneous-computing Interface for Portability (HIP), used for GPU interop." From b27b9ca3106b33d7240dc2b9dead70273b2e12f8 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 23 Dec 2024 13:36:49 +0100 Subject: [PATCH 06/11] 144 add GitHub pages (#149) * add docs step to release job * fix path * coverage installs llvm-cov --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 75 ++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 906a035..d014683 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,43 +1,46 @@ name: CI on: - pull_request: - branches: ["main"] + pull_request: + branches: ["main"] env: - CARGO_TERM_COLOR: always + CARGO_TERM_COLOR: always jobs: - test: - name: Test - runs-on: self-hosted - steps: - - uses: actions/checkout@v4 - - - name: Rust Cache - uses: Swatinem/rust-cache@v2 - - - name: Check formatting - run: cargo fmt --all -- --check - - # - name: Clippy - # run: cargo clippy -- -D warnings - - - name: Run tests - run: cargo test --verbose - - coverage: - name: Coverage - runs-on: self-hosted - steps: - - uses: actions/checkout@v4 - - - name: Generate code coverage - run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - files: lcov.info - fail_ci_if_error: true + test: + name: Test + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + + - name: Check formatting + run: cargo fmt --all -- --check + + # - name: Clippy + # run: cargo clippy -- -D warnings + + - name: Run tests + run: cargo test --verbose + + coverage: + name: Coverage + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + + - name: Install llvm-cov + run: cargo install llvm-cov + + - name: Generate code coverage + run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: lcov.info + fail_ci_if_error: true From 631a586789c74fbe37861d4472a50702637fc1c1 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 23 Dec 2024 14:08:07 +0100 Subject: [PATCH 07/11] 144 add GitHub pages (#150) * add docs step to release job * fix path * coverage installs llvm-cov * fix install cargo-llvm-cov --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d014683..935470f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,8 +32,8 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Install llvm-cov - run: cargo install llvm-cov + - name: Install cargo-llvm-cov + run: cargo install cargo-llvm-cov - name: Generate code coverage run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info From 82c975feaceabdba4dc4573a91ca5337a3e43650 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 23 Dec 2024 15:53:56 +0100 Subject: [PATCH 08/11] 144 add GitHub pages (#152) * add docs step to release job * fix path * coverage installs llvm-cov * fix install cargo-llvm-cov * fix path --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 3 --- .github/workflows/release.yaml | 7 +++++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 935470f..4e848a2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,9 +20,6 @@ jobs: - name: Check formatting run: cargo fmt --all -- --check - # - name: Clippy - # run: cargo clippy -- -D warnings - - name: Run tests run: cargo test --verbose diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index a532f2b..5766c69 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -59,6 +59,13 @@ jobs: uses: actions/checkout@v4 with: ref: release + - name: Generate Docs + run: cargo doc --no-deps --document-private-items + - name: Add index.html + run: | + echo '' > target/doc/index.html + - name: Deploy + uses: peaceiris/actions-gh-pages@v3 - name: Setup Rust uses: dtolnay/rust-toolchain@stable - name: Configure cache From 29fc0edc7e97acc78f847a5d8943993cc77cafb7 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 30 Dec 2024 07:46:36 +0100 Subject: [PATCH 09/11] 153 add hipblas (#154) * move bindings to src/sys/. Add hipblas/ * add hipBLAS link to build.rs --- build.rs | 5 +- src/core/device.rs | 2 +- src/core/device_types.rs | 2 +- src/core/hip_call.rs | 3 +- src/core/init.rs | 2 +- src/core/mod.rs | 2 +- src/hipblas/handle.rs | 198 +++++++++++++++++++++++++++++++++++ src/hipblas/mod.rs | 7 ++ src/hipblas/types.rs | 50 +++++++++ src/lib.rs | 3 + src/{core => }/sys/mod.rs | 0 src/{core => }/sys/wrapper.h | 1 + 12 files changed, 269 insertions(+), 6 deletions(-) create mode 100644 src/hipblas/handle.rs create mode 100644 src/hipblas/mod.rs create mode 100644 src/hipblas/types.rs rename src/{core => }/sys/mod.rs (100%) rename src/{core => }/sys/wrapper.h (50%) diff --git a/build.rs b/build.rs index b00791b..8f638f8 100644 --- a/build.rs +++ b/build.rs @@ -12,6 +12,9 @@ fn main() { return; } + // link hipBLAS + println!("cargo:rustc-link-lib=dylib=hipblas"); + // Tell cargo when to rerun this build script println!("cargo:rerun-if-changed=src/core/sys/wrapper.h"); println!("cargo:rerun-if-changed=build.rs"); @@ -38,7 +41,7 @@ fn main() { fn generate_bindings(hip_include_path: &str) { let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); let bindings = bindgen::Builder::default() - .header("src/core/sys/wrapper.h") + .header("src/sys/wrapper.h") .clang_arg(&format!("-I{}", hip_include_path)) .clang_arg("-D__HIP_PLATFORM_AMD__") // Blocklist problematic items diff --git a/src/core/device.rs b/src/core/device.rs index ddb162e..1de64ab 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -1,5 +1,5 @@ -use super::sys; use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, MemPool, PCIBusId, Result}; +use crate::sys; use semver::Version; use std::ffi::CStr; use std::i32; diff --git a/src/core/device_types.rs b/src/core/device_types.rs index 8fe7987..bb7fdb1 100644 --- a/src/core/device_types.rs +++ b/src/core/device_types.rs @@ -1,5 +1,5 @@ -use super::sys; use super::{HipError, HipErrorKind, HipResult, Result}; +use crate::sys; use std::ffi::CStr; use std::i32; diff --git a/src/core/hip_call.rs b/src/core/hip_call.rs index 5bb0354..86b6e72 100644 --- a/src/core/hip_call.rs +++ b/src/core/hip_call.rs @@ -1,4 +1,5 @@ -use super::{sys, HipResult, Result}; +use super::{HipResult, Result}; +use crate::sys; #[macro_export] macro_rules! hip_call { diff --git a/src/core/init.rs b/src/core/init.rs index 78c6496..9006282 100644 --- a/src/core/init.rs +++ b/src/core/init.rs @@ -1,4 +1,4 @@ -use super::sys; +use crate::sys; use crate::{HipResult, Result}; use semver::Version; use std::i32; diff --git a/src/core/mod.rs b/src/core/mod.rs index 2e30cac..197ae87 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -6,8 +6,8 @@ mod init; mod memory; mod result; mod stream; -pub mod sys; +// use crate::sys::*; // Re-export core functionality pub use device::*; pub use device_types::*; diff --git a/src/hipblas/handle.rs b/src/hipblas/handle.rs new file mode 100644 index 0000000..1d25c17 --- /dev/null +++ b/src/hipblas/handle.rs @@ -0,0 +1,198 @@ +use crate::sys; +use crate::{HipResult, Result}; +use std::fmt; + +/// A handle to a hipBLAS library context. +/// +/// This handle is required for all hipBLAS library calls and encapsulates the +/// hipBLAS library context. The context includes the HIP device number and +/// stream used for all hipBLAS operations using this handle. +/// +/// # Thread Safety +/// +/// The handle is thread-safe and can be shared between threads. It implements +/// Send and Sync traits. +/// +/// # Examples +/// +/// ``` +/// use hip_rs::BlasHandle; +/// +/// let handle = BlasHandle::new().unwrap(); +/// // Use handle for hipBLAS operations +/// ``` +#[derive(Debug)] +pub struct BlasHandle { + handle: sys::hipblasHandle_t, +} + +impl BlasHandle { + /// Creates a new hipBLAS library context. + /// + /// # Returns + /// + /// * `Ok(BlasHandle)` - A new handle for hipBLAS operations + /// * `Err(HipError)` - If handle creation fails + /// + /// # Examples + /// + /// ``` + /// use hip_rs::BlasHandle; + /// + /// let handle = BlasHandle::new().unwrap(); + /// ``` + pub fn new() -> Result { + let mut handle = std::ptr::null_mut(); + unsafe { + let status = sys::hipblasCreate(&mut handle); + (Self { handle }, status).to_result() + } + } + + /// Returns the raw hipBLAS handle. + /// + /// # Safety + /// + /// The returned handle should not be destroyed manually or used after + /// the BlasHandle is dropped. + pub fn handle(&self) -> sys::hipblasHandle_t { + self.handle + } +} + +// Implement Drop to clean up the handle +impl Drop for BlasHandle { + fn drop(&mut self) { + if !self.handle.is_null() { + unsafe { + let status = sys::hipblasDestroy(self.handle); + if status != 0 { + log::error!("Failed to destroy hipBLAS handle: {}", status); + } + } + } + } +} + +// Implement Display for better error messages +impl fmt::Display for BlasHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "BlasHandle({:p})", self.handle) + } +} + +// Implement Send and Sync as hipBLAS handles are thread-safe +unsafe impl Send for BlasHandle {} +unsafe impl Sync for BlasHandle {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_handle_create() { + let handle = BlasHandle::new(); + assert!(handle.is_ok(), "Failed to create BlasHandle"); + let handle = handle.unwrap(); + assert!(!handle.handle().is_null(), "Handle is null after creation"); + } + + #[test] + fn test_handle_drop() { + let handle = BlasHandle::new().unwrap(); + drop(handle); // Should not panic or cause memory leaks + } + + #[test] + fn test_multiple_handles() { + // Create multiple handles to ensure they don't interfere + let handle1 = BlasHandle::new().unwrap(); + let handle2 = BlasHandle::new().unwrap(); + + assert!(!handle1.handle().is_null()); + assert!(!handle2.handle().is_null()); + assert_ne!( + handle1.handle(), + handle2.handle(), + "Handles should be unique" + ); + } + + #[test] + fn test_handle_clone_not_implemented() { + let handle = BlasHandle::new().unwrap(); + // This should fail to compile if you try to uncomment it + // let _cloned = handle.clone(); + } + + #[test] + fn test_handle_send_sync() { + // Test that handle can be sent between threads + let handle = BlasHandle::new().unwrap(); + let handle_ptr = handle.handle(); + + let handle = std::thread::spawn(move || { + assert!(!handle.handle().is_null()); + handle + }) + .join() + .unwrap(); + + assert_eq!(handle.handle(), handle_ptr); + } + + #[test] + fn test_handle_concurrent_use() { + use std::sync::Arc; + use std::thread; + + let handle = Arc::new(BlasHandle::new().unwrap()); + let mut threads = vec![]; + + // Spawn multiple threads using the same handle + for _ in 0..4 { + let handle_clone = Arc::clone(&handle); + threads.push(thread::spawn(move || { + assert!(!handle_clone.handle().is_null()); + })); + } + + // Wait for all threads to complete + for thread in threads { + thread.join().unwrap(); + } + } + + #[test] + fn test_handle_in_closure() { + let handle = BlasHandle::new().unwrap(); + let closure = || { + assert!(!handle.handle().is_null()); + }; + closure(); + } + + #[test] + fn test_handle_debug_format() { + let handle = BlasHandle::new().unwrap(); + let debug_str = format!("{:?}", handle); + assert!(!debug_str.is_empty(), "Debug formatting failed"); + println!("Debug format of BlasHandle: {}", debug_str); + } + + #[test] + fn test_handle_memory_stress() { + // Create and destroy multiple handles in a loop + for _ in 0..100 { + let handle = BlasHandle::new().unwrap(); + assert!(!handle.handle().is_null()); + drop(handle); + } + } + + #[test] + fn test_handle_null_check() { + let handle = BlasHandle::new().unwrap(); + assert!(!handle.handle().is_null(), "Handle should not be null"); + } +} diff --git a/src/hipblas/mod.rs b/src/hipblas/mod.rs new file mode 100644 index 0000000..7a0a0c5 --- /dev/null +++ b/src/hipblas/mod.rs @@ -0,0 +1,7 @@ +mod handle; +mod types; + +use crate::sys; + +pub use handle::*; +pub use types::*; diff --git a/src/hipblas/types.rs b/src/hipblas/types.rs new file mode 100644 index 0000000..bd7dc07 --- /dev/null +++ b/src/hipblas/types.rs @@ -0,0 +1,50 @@ +use crate::sys; + +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Operation { + None = 0, // HIPBLAS_OP_N + Transpose = 1, // HIPBLAS_OP_T + Conjugate = 2, // HIPBLAS_OP_C +} + +impl From for sys::hipblasOperation_t { + fn from(op: Operation) -> Self { + op as sys::hipblasOperation_t + } +} + +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Status { + Success = 0, + Handle = 1, + NotInitialized = 2, + InvalidValue = 3, + ArchMismatch = 4, + MappingError = 5, + ExecutionFailed = 6, + InternalError = 7, + NotSupported = 8, + MemoryError = 9, + AllocationFailed = 10, +} + +impl From for Status { + fn from(status: sys::hipblasStatus_t) -> Self { + match status { + 0 => Status::Success, + 1 => Status::Handle, + 2 => Status::NotInitialized, + 3 => Status::InvalidValue, + 4 => Status::ArchMismatch, + 5 => Status::MappingError, + 6 => Status::ExecutionFailed, + 7 => Status::InternalError, + 8 => Status::NotSupported, + 9 => Status::MemoryError, + 10 => Status::AllocationFailed, + _ => Status::InternalError, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index a5b8960..7a8a26e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,7 @@ #![allow(non_upper_case_globals)] mod core; +mod hipblas; +mod sys; pub use core::*; +pub use hipblas::*; diff --git a/src/core/sys/mod.rs b/src/sys/mod.rs similarity index 100% rename from src/core/sys/mod.rs rename to src/sys/mod.rs diff --git a/src/core/sys/wrapper.h b/src/sys/wrapper.h similarity index 50% rename from src/core/sys/wrapper.h rename to src/sys/wrapper.h index 016836d..b4b0b1a 100644 --- a/src/core/sys/wrapper.h +++ b/src/sys/wrapper.h @@ -1 +1,2 @@ #include +#include From 49ce1986f3a88e8fa92721ff76fe5f92e8922116 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Mon, 30 Dec 2024 18:23:15 +0100 Subject: [PATCH 10/11] 155 add gemm (#156) * add src/result.rs. Refactor * add hipblas/types/Complex32 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/core/device.rs | 36 ++-- src/core/device_types.rs | 6 +- src/core/hip_call.rs | 8 +- src/core/init.rs | 7 +- src/core/memory.rs | 34 ++-- src/core/result.rs | 161 +++++++++------ src/core/stream.rs | 8 +- src/hipblas/gemm.rs | 413 +++++++++++++++++++++++++++++++++++++++ src/hipblas/handle.rs | 10 +- src/hipblas/mod.rs | 6 +- src/hipblas/result.rs | 163 +++++++++++++++ src/hipblas/types.rs | 122 +++++++++--- src/lib.rs | 2 + src/result.rs | 144 ++++++++++++++ 14 files changed, 987 insertions(+), 133 deletions(-) create mode 100644 src/hipblas/gemm.rs create mode 100644 src/hipblas/result.rs create mode 100644 src/result.rs diff --git a/src/core/device.rs b/src/core/device.rs index 1de64ab..9cabea3 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -1,4 +1,6 @@ -use super::{DeviceP2PAttribute, HipError, HipErrorKind, HipResult, MemPool, PCIBusId, Result}; +use super::result::{HipError, HipResult, HipStatus}; +use super::{DeviceP2PAttribute, MemPool, PCIBusId}; +use crate::result::ResultExt; use crate::sys; use semver::Version; use std::ffi::CStr; @@ -42,7 +44,7 @@ impl Device { /// # Returns /// * `Result` - On success, returns a `Version` struct containing the major and minor version /// numbers of the device's compute capability. On failure, returns an error indicating what went wrong. - pub fn device_compute_capability(&self) -> Result { + pub fn device_compute_capability(&self) -> HipResult { unsafe { let mut major: i32 = -1; let mut minor: i32 = -1; @@ -61,7 +63,7 @@ impl Device { /// Returns `HipError` if: /// * The device is invalid /// * The runtime is not initialized - pub fn device_total_mem(&self) -> Result { + pub fn device_total_mem(&self) -> HipResult { unsafe { let mut size: usize = 0; let code = sys::hipDeviceTotalMem(&mut size, self.id); @@ -80,7 +82,7 @@ impl Device { /// * The device ID is invalid /// * There was an error retrieving the device name /// * The name string could not be converted to valid UTF-8 - pub fn get_device_name(&self) -> Result { + pub fn get_device_name(&self) -> HipResult { const buffer_size: usize = 64; let mut buffer = vec![0i8; buffer_size]; @@ -105,7 +107,7 @@ impl Device { /// * The device is invalid /// * The runtime is not initialized /// * There was an error retrieving the UUID - fn get_device_uuid_bytes(&self) -> Result<[i8; 16]> { + fn get_device_uuid_bytes(&self) -> HipResult<[i8; 16]> { let mut hip_bytes = sys::hipUUID_t { bytes: [0; 16] }; unsafe { let code = sys::hipDeviceGetUuid(&mut hip_bytes, self.id); @@ -128,7 +130,7 @@ impl Device { /// * The device is invalid /// * The runtime is not initialized /// * There was an error retrieving the UUID - pub fn get_device_uuid(&self) -> Result { + pub fn get_device_uuid(&self) -> HipResult { Self::get_device_uuid_bytes(self).map(|bytes| { let uuid_bytes: [u8; 16] = bytes.map(|b| b as u8); Uuid::from_bytes(uuid_bytes) @@ -148,7 +150,7 @@ impl Device { /// * The device is invalid /// * The runtime is not initialized /// * There was an error retrieving the PCI bus ID - pub fn get_device_pci_bus_id(&self) -> Result { + pub fn get_device_pci_bus_id(&self) -> HipResult { let mut pci_bus_id = PCIBusId::new(); unsafe { let code = @@ -167,7 +169,7 @@ impl Device { /// * The device ID is invalid /// * The operation is not supported on this device/platform /// * There was an error retrieving the memory pool - pub fn get_default_mem_pool(&self) -> Result { + pub fn get_default_mem_pool(&self) -> HipResult { let mut mem_pool = std::ptr::null_mut(); unsafe { let code = sys::hipDeviceGetDefaultMemPool(&mut mem_pool, self.id); @@ -191,7 +193,7 @@ impl Device { /// Returns `HipError` if: /// * No device is currently active /// * The HIP runtime is not initialized -pub fn synchronize() -> Result<()> { +pub fn synchronize() -> HipResult<()> { unsafe { let code = sys::hipDeviceSynchronize(); ((), code).to_result() @@ -207,7 +209,7 @@ pub fn synchronize() -> Result<()> { /// Returns `HipError` if: /// * The runtime is not initialized (`HipErrorKind::NotInitialized`) /// * The operation fails for other reasons -pub fn get_device_count() -> Result { +pub fn get_device_count() -> HipResult { unsafe { let mut count = 0; let code = sys::hipGetDeviceCount(&mut count); @@ -239,7 +241,7 @@ pub fn get_device_p2p_attribute( attr: DeviceP2PAttribute, src_device: Device, dst_device: Device, -) -> Result { +) -> HipResult { let mut value = -1; unsafe { let code = @@ -260,7 +262,7 @@ pub fn get_device_p2p_attribute( /// * No device is currently active /// * HIP runtime is not initialized /// * There was an error accessing device information -pub fn get_device() -> Result { +pub fn get_device() -> HipResult { unsafe { let mut device_id: i32 = -1; let code = sys::hipGetDevice(&mut device_id); @@ -285,7 +287,7 @@ pub fn get_device() -> Result { /// * The device ID is invalid (greater than or equal to device count) /// * The HIP runtime is not initialized /// * The specified device has encountered a previous error and is in a broken state -pub fn set_device(device: Device) -> Result { +pub fn set_device(device: Device) -> HipResult { unsafe { let code = sys::hipSetDevice(device.id); (device, code).to_result() @@ -305,7 +307,7 @@ pub fn set_device(device: Device) -> Result { /// * The PCI bus ID string is invalid /// * No device with the specified PCI bus ID exists /// * The runtime is not initialized -pub fn get_device_by_pci_bus_id(mut pci_bus_id: PCIBusId) -> Result { +pub fn get_device_by_pci_bus_id(mut pci_bus_id: PCIBusId) -> HipResult { let mut device_id = i32::MAX; unsafe { let code = sys::hipDeviceGetByPCIBusId(&mut device_id, pci_bus_id.as_mut_ptr()); @@ -330,7 +332,7 @@ mod tests { } Err(e) => { // Check if the error is "not supported" which is acceptable - if e.kind != HipErrorKind::NotSupported { + if e.status != HipStatus::NotSupported { panic!("Unexpected error getting default memory pool: {:?}", e); } println!("Memory pools not supported on this device/platform"); @@ -369,7 +371,7 @@ mod tests { let invalid_device = Device::new(99); let result = invalid_device.get_device_pci_bus_id(); assert!(result.is_err()); - assert_eq!(result.unwrap_err().kind, HipErrorKind::InvalidDevice); + assert_eq!(result.unwrap_err().status, HipStatus::InvalidDevice); } #[test] @@ -452,6 +454,6 @@ mod tests { let invalid_device = Device::new(99); let result = set_device(invalid_device); assert!(result.is_err()); - assert_eq!(result.unwrap_err().kind, HipErrorKind::InvalidDevice); + assert_eq!(result.unwrap_err().status, HipStatus::InvalidDevice); } } diff --git a/src/core/device_types.rs b/src/core/device_types.rs index bb7fdb1..d7bae87 100644 --- a/src/core/device_types.rs +++ b/src/core/device_types.rs @@ -1,4 +1,4 @@ -use super::{HipError, HipErrorKind, HipResult, Result}; +use super::result::{HipError, HipResult, HipStatus}; use crate::sys; use std::ffi::CStr; use std::i32; @@ -86,7 +86,7 @@ impl From for u32 { impl TryFrom for DeviceP2PAttribute { type Error = HipError; - fn try_from(value: sys::hipDeviceP2PAttr) -> Result { + fn try_from(value: sys::hipDeviceP2PAttr) -> Result { match value { sys::hipDeviceP2PAttr_hipDevP2PAttrPerformanceRank => Ok(Self::PerformanceRank), sys::hipDeviceP2PAttr_hipDevP2PAttrAccessSupported => Ok(Self::AccessSupported), @@ -96,7 +96,7 @@ impl TryFrom for DeviceP2PAttribute { sys::hipDeviceP2PAttr_hipDevP2PAttrHipArrayAccessSupported => { Ok(Self::HipArrayAccessSupported) } - _ => Err(HipError::from_kind(HipErrorKind::InvalidValue)), + _ => Err(HipError::from_status(HipStatus::InvalidValue)), } } } diff --git a/src/core/hip_call.rs b/src/core/hip_call.rs index 86b6e72..41ad169 100644 --- a/src/core/hip_call.rs +++ b/src/core/hip_call.rs @@ -1,11 +1,13 @@ -use super::{HipResult, Result}; +use super::result::{HipResult, HipStatus}; +use crate::result::ResultExt; use crate::sys; #[macro_export] macro_rules! hip_call { ($call:expr) => {{ - let code = unsafe { $call }; - ((), code).to_result() + let code: u32 = unsafe { $call }; + let result: HipResult<()> = ((), code).to_result(); + result }}; } diff --git a/src/core/init.rs b/src/core/init.rs index 9006282..66b2a7e 100644 --- a/src/core/init.rs +++ b/src/core/init.rs @@ -1,5 +1,6 @@ +use super::result::{HipResult, HipStatus}; +pub use crate::result::ResultExt; use crate::sys; -use crate::{HipResult, Result}; use semver::Version; use std::i32; @@ -15,7 +16,7 @@ use std::i32; /// Returns `HipError` if: /// * The runtime fails to initialize /// * The runtime is already initialized -pub fn initialize() -> Result<()> { +pub fn initialize() -> HipResult<()> { unsafe { let code = sys::hipInit(0); ((), code).to_result() @@ -50,7 +51,7 @@ fn decode_hip_version(version: i32) -> Version { /// Returns `HipError` if: /// * The runtime is not initialized /// * Getting the version fails -pub fn runtime_get_version() -> Result { +pub fn runtime_get_version() -> HipResult { unsafe { let mut version: i32 = -1; let code = sys::hipRuntimeGetVersion(&mut version); diff --git a/src/core/memory.rs b/src/core/memory.rs index e0e2deb..2a2cd85 100644 --- a/src/core/memory.rs +++ b/src/core/memory.rs @@ -1,6 +1,8 @@ use super::flags::DeviceMallocFlag; -use super::{HipError, HipErrorKind, HipResult, Result, Stream}; +use super::result::{HipError, HipResult, HipStatus}; +use crate::result::ResultExt; use crate::sys; +use crate::Stream; /// A wrapper for device memory allocated on the GPU. /// Automatically frees the memory when dropped. @@ -166,7 +168,7 @@ unsafe fn memory_copy( src: *const std::ffi::c_void, size: usize, kind: MemoryCopyKind, -) -> Result<()> { +) -> HipResult<()> { let code = sys::hipMemcpy(dst, src, size, kind.into()); ((), code).to_result() } @@ -176,7 +178,7 @@ impl MemoryPointer { /// memory allocation functions. /// /// Takes the size to allocate and - fn allocate_with_fn(size: usize, alloc_fn: F) -> Result + fn allocate_with_fn(size: usize, alloc_fn: F) -> HipResult where F: FnOnce(*mut *mut std::ffi::c_void, usize) -> u32, { @@ -214,7 +216,7 @@ impl MemoryPointer { /// * `Ok(MemoryPointer)` - Handle to allocated device memory /// * `Err(HipError)` - Error occurred during allocation /// ``` - pub fn alloc(size: usize) -> Result { + pub fn alloc(size: usize) -> HipResult { Self::allocate_with_fn(size, |ptr, size| unsafe { sys::hipMalloc(ptr, size) }) } @@ -233,7 +235,7 @@ impl MemoryPointer { /// * If size is 0, returns null pointer with success status /// * Invalid flags will result in hipErrorInvalidValue error /// - pub fn alloc_with_flag(size: usize, flag: DeviceMallocFlag) -> Result { + pub fn alloc_with_flag(size: usize, flag: DeviceMallocFlag) -> HipResult { Self::allocate_with_fn(size, |ptr, size| unsafe { sys::hipExtMallocWithFlags(ptr, size, flag.bits()) }) @@ -258,7 +260,7 @@ impl MemoryPointer { /// /// let ptr = MemoryPointer::::alloc_async(1024, &stream).unwrap(); /// ``` - pub fn alloc_async(size: usize, stream: &Stream) -> Result { + pub fn alloc_async(size: usize, stream: &Stream) -> HipResult { Self::allocate_with_fn(size, |ptr, size| unsafe { sys::hipMallocAsync(ptr, size, stream.handle()) }) @@ -288,15 +290,15 @@ impl MemoryPointer { /// - Checks that neither pointer is null /// - Validates that destination has sufficient size /// - Ensures proper size alignment - pub fn copy_to(&self, destination: &MemoryPointer, kind: MemoryCopyKind) -> Result<()> { + pub fn copy_to(&self, destination: &MemoryPointer, kind: MemoryCopyKind) -> HipResult<()> { // Check for null pointers if self.pointer.is_null() || destination.pointer.is_null() { - return Err(HipError::from_kind(HipErrorKind::InvalidValue)); + return Err(HipError::from_status(HipStatus::InvalidValue)); } // Check that destination has sufficient size if destination.size < self.size { - return Err(HipError::from_kind(HipErrorKind::InvalidValue)); + return Err(HipError::from_status(HipStatus::InvalidValue)); } // Calculate total bytes to copy @@ -319,7 +321,7 @@ impl MemoryPointer { /// * `size` - Number of bytes to fill. Must not exceed the allocated size. /// /// # Returns - /// * `Result<()>` - Success or error status + /// * `HipResult<()>` - Success or error status /// /// # Examples /// ``` @@ -328,10 +330,10 @@ impl MemoryPointer { /// let mut ptr = MemoryPointer::::alloc(1024).unwrap(); /// ptr.memset(0, 1024).unwrap(); // Zero-initialize memory /// ``` - pub fn memset(&self, value: u8, size: usize) -> Result<()> { + pub fn memset(&self, value: u8, size: usize) -> HipResult<()> { // Validate size doesn't exceed allocation if size > self.size { - return Err(HipError::from_kind(HipErrorKind::InvalidValue)); + return Err(HipError::from_status(HipStatus::InvalidValue)); } if size == 0 { @@ -352,7 +354,7 @@ impl Drop for MemoryPointer { let code = sys::hipFree(self.pointer as *mut std::ffi::c_void); if code != 0 { let error = HipError::new(code); - log::error!("MemoryPointer failed to free memory: {}", error); + log::error!("MemoryPointer failed to free memory: {:?}", error); } } } @@ -374,7 +376,7 @@ impl MemPool { } } - pub fn create(props: MemPoolProps) -> Result { + pub fn create(props: MemPoolProps) -> HipResult { let mut handle = std::ptr::null_mut(); let sys_props = props.to_sys_props(); @@ -437,7 +439,7 @@ impl From for u32 { impl TryFrom for MemoryCopyKind { type Error = HipError; - fn try_from(value: sys::hipMemcpyKind) -> Result { + fn try_from(value: sys::hipMemcpyKind) -> HipResult { match value { 0 => Ok(Self::HostToHost), 1 => Ok(Self::HostToDevice), @@ -445,7 +447,7 @@ impl TryFrom for MemoryCopyKind { 3 => Ok(Self::DeviceToDevice), 4 => Ok(Self::Default), 1024 => Ok(Self::DeviceToDeviceNoCU), - _ => Err(HipError::from_kind(HipErrorKind::InvalidValue)), + _ => Err(HipError::from_status(HipStatus::InvalidValue)), } } } diff --git a/src/core/result.rs b/src/core/result.rs index 7f238a6..7c1ab4c 100644 --- a/src/core/result.rs +++ b/src/core/result.rs @@ -1,23 +1,9 @@ -use std::fmt; +use crate::result::{ResultExt, StatusCode}; -/// Success code from HIP runtime #[repr(u32)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum HipSuccess { +pub enum HipStatus { Success = 0, -} - -impl HipSuccess { - pub fn new() -> Self { - Self::Success - } -} - -/// Error codes from HIP runtime -/// https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/hip__runtime__api_8h.html#a657deda9809cdddcbfcd336a29894635 -#[repr(u32)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum HipErrorKind { InvalidValue = 1, MemoryAllocation = 2, NotInitialized = 3, @@ -29,73 +15,136 @@ pub enum HipErrorKind { Unknown = 999, } -impl HipErrorKind { - /// Convert from raw HIP error code to HipErrorKind - pub fn from_raw(error: u32) -> Self { - match error { - 1 => HipErrorKind::InvalidValue, - 2 => HipErrorKind::MemoryAllocation, - 3 => HipErrorKind::NotInitialized, - 4 => HipErrorKind::Deinitialized, - 101 => HipErrorKind::InvalidDevice, - 301 => HipErrorKind::FileNotFound, - 600 => HipErrorKind::NotReady, - 801 => HipErrorKind::NotSupported, - _ => HipErrorKind::Unknown, +impl HipStatus { + fn from(status: u32) -> Self { + match status { + 0 => HipStatus::Success, + 1 => HipStatus::InvalidValue, + 2 => HipStatus::MemoryAllocation, + 3 => HipStatus::NotInitialized, + 4 => HipStatus::Deinitialized, + 101 => HipStatus::InvalidDevice, + 301 => HipStatus::FileNotFound, + 600 => HipStatus::NotReady, + 801 => HipStatus::NotSupported, + _ => HipStatus::Unknown, } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct HipError { - pub kind: HipErrorKind, + pub status: HipStatus, pub code: u32, } impl HipError { - pub fn new(code: u32) -> Self { + pub(crate) fn new(code: u32) -> Self { Self { - kind: HipErrorKind::from_raw(code), + status: HipStatus::from(code), code, } } - pub fn from_kind(kind: HipErrorKind) -> Self { + pub fn from_status(status: HipStatus) -> Self { Self { - kind, - code: kind as u32, + status, + code: status as u32, } } } -impl fmt::Display for HipError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "HIP error: {:?} (code: {})", self.kind, self.code) +impl StatusCode for HipError { + fn is_success(&self) -> bool { + self.status == HipStatus::Success } -} - -impl std::error::Error for HipError {} -pub type Result = std::result::Result; + fn code(&self) -> u32 { + self.code as u32 + } -/// Trait for checking HIP operation results -pub trait HipResult { - /// The successful value type - type Value; + fn kind_str(&self) -> &'static str { + "HIP" + } - /// Convert HIP result to Result type - fn to_result(self) -> Result; + fn status_str(&self) -> &'static str { + match self.status { + HipStatus::Success => "Success", + HipStatus::InvalidValue => "InvalidValue", + HipStatus::MemoryAllocation => "MemoryAllocation", + HipStatus::NotInitialized => "NotInitialized", + HipStatus::Deinitialized => "Deinitialized", + HipStatus::InvalidDevice => "InvalidDevice", + HipStatus::FileNotFound => "FileNotFound", + HipStatus::NotReady => "NotReady", + HipStatus::NotSupported => "NotSupported", + HipStatus::Unknown => "Unknown", + } + } } -/// Implement for tuple of (value, error_code) -impl HipResult for (T, u32) { +pub type HipResult = std::result::Result; + +impl ResultExt for (T, u32) { type Value = T; + fn to_result(self) -> HipResult { + let (value, status) = self; + (value, HipError::new(status)).to_result() + } +} - fn to_result(self) -> Result { - let (value, code) = self; - match code { - 0 => Ok(value), - _ => Err(HipError::new(code)), - } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hip_status_from() { + assert_eq!(HipStatus::from(0), HipStatus::Success); + assert_eq!(HipStatus::from(1), HipStatus::InvalidValue); + assert_eq!(HipStatus::from(2), HipStatus::MemoryAllocation); + assert_eq!(HipStatus::from(3), HipStatus::NotInitialized); + assert_eq!(HipStatus::from(4), HipStatus::Deinitialized); + assert_eq!(HipStatus::from(101), HipStatus::InvalidDevice); + assert_eq!(HipStatus::from(301), HipStatus::FileNotFound); + assert_eq!(HipStatus::from(600), HipStatus::NotReady); + assert_eq!(HipStatus::from(801), HipStatus::NotSupported); + assert_eq!(HipStatus::from(1000), HipStatus::Unknown); + } + + #[test] + fn test_hip_error_new() { + let error = HipError::new(1); + assert_eq!(error.status, HipStatus::InvalidValue); + assert_eq!(error.code, 1); + } + + #[test] + fn test_hip_error_from_status() { + let error = HipError::from_status(HipStatus::InvalidValue); + assert_eq!(error.status, HipStatus::InvalidValue); + assert_eq!(error.code, 1); + } + + #[test] + fn test_hip_error_status_code() { + let error = HipError::new(0); + assert!(error.is_success()); + assert_eq!(error.code(), 0); + assert_eq!(error.kind_str(), "HIP"); + + let error = HipError::new(1); + assert!(!error.is_success()); + assert_eq!(error.code(), 1); + } + + #[test] + fn test_result_ext() { + let success: HipResult = (42, 0).to_result(); + assert!(success.is_ok()); + assert_eq!(success.unwrap(), 42); + + let error: HipResult = (42, 1).to_result(); + assert!(error.is_err()); + assert_eq!(error.unwrap_err().code, 1); } } diff --git a/src/core/stream.rs b/src/core/stream.rs index 21a4070..171b546 100644 --- a/src/core/stream.rs +++ b/src/core/stream.rs @@ -1,4 +1,6 @@ -use crate::{sys, HipError, HipResult, Result}; +use super::result::{HipResult, HipStatus}; +use crate::result::ResultExt; +use crate::sys; /// A handle to a HIP stream that executes commands in order. #[derive(Debug)] @@ -22,7 +24,7 @@ impl Stream { /// /// let stream = Stream::create().unwrap(); /// ``` - pub fn create() -> Result { + pub fn create() -> HipResult { let mut stream: sys::hipStream_t = std::ptr::null_mut(); unsafe { let code = sys::hipStreamCreate(&mut stream); @@ -49,7 +51,7 @@ impl Stream { /// * `Err(HipError)` - Either: /// - `HipErrorKind::NotReady` if operations are still in progress /// - `HipErrorKind::InvalidHandle` if the stream handle is invalid - pub fn query_stream(&self) -> Result<()> { + pub fn query_stream(&self) -> HipResult<()> { unsafe { let code = sys::hipStreamQuery(self.handle); ((), code).to_result() diff --git a/src/hipblas/gemm.rs b/src/hipblas/gemm.rs new file mode 100644 index 0000000..7c251a8 --- /dev/null +++ b/src/hipblas/gemm.rs @@ -0,0 +1,413 @@ +use super::{BlasHandle, Operation, Result}; +use crate::result::ResultExt; +use crate::Complex32; +use crate::{sys, MemoryPointer}; + +/// Trait for types supported by GEMM operations +pub trait GemmDatatype { + /// Calls the appropriate HIPBLAS GEMM function for this datatype + unsafe fn hipblas_gemm( + handle: sys::hipblasHandle_t, + trans_a: sys::hipblasOperation_t, + trans_b: sys::hipblasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const Self, + a: *const Self, + lda: i32, + b: *const Self, + ldb: i32, + beta: *const Self, + c: *mut Self, + ldc: i32, + ) -> sys::hipblasStatus_t; +} + +// u16 +impl GemmDatatype for sys::hipblasHalf { + unsafe fn hipblas_gemm( + handle: sys::hipblasHandle_t, + trans_a: sys::hipblasOperation_t, + trans_b: sys::hipblasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const Self, + a: *const Self, + lda: i32, + b: *const Self, + ldb: i32, + beta: *const Self, + c: *mut Self, + ldc: i32, + ) -> sys::hipblasStatus_t { + sys::hipblasHgemm( + handle, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + ) + } +} + +impl GemmDatatype for f32 { + unsafe fn hipblas_gemm( + handle: sys::hipblasHandle_t, + trans_a: sys::hipblasOperation_t, + trans_b: sys::hipblasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const Self, + a: *const Self, + lda: i32, + b: *const Self, + ldb: i32, + beta: *const Self, + c: *mut Self, + ldc: i32, + ) -> sys::hipblasStatus_t { + sys::hipblasSgemm( + handle, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + ) + } +} + +impl GemmDatatype for f64 { + unsafe fn hipblas_gemm( + handle: sys::hipblasHandle_t, + trans_a: sys::hipblasOperation_t, + trans_b: sys::hipblasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const Self, + a: *const Self, + lda: i32, + b: *const Self, + ldb: i32, + beta: *const Self, + c: *mut Self, + ldc: i32, + ) -> sys::hipblasStatus_t { + sys::hipblasDgemm( + handle, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + ) + } +} + +impl GemmDatatype for Complex32 { + unsafe fn hipblas_gemm( + handle: sys::hipblasHandle_t, + trans_a: sys::hipblasOperation_t, + trans_b: sys::hipblasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const Self, + a: *const Self, + lda: i32, + b: *const Self, + ldb: i32, + beta: *const Self, + c: *mut Self, + ldc: i32, + ) -> sys::hipblasStatus_t { + // Convert Complex32 pointers to hipblasComplex pointers + let alpha_ptr = alpha as *const sys::hipblasComplex; + let a_ptr = a as *const sys::hipblasComplex; + let b_ptr = b as *const sys::hipblasComplex; + let beta_ptr = beta as *const sys::hipblasComplex; + let c_ptr = c as *mut sys::hipblasComplex; + + sys::hipblasCgemm( + handle, trans_a, trans_b, m, n, k, alpha_ptr, a_ptr, lda, b_ptr, ldb, beta_ptr, c_ptr, + ldc, + ) + } +} + +impl GemmDatatype for sys::hipblasDoubleComplex { + unsafe fn hipblas_gemm( + handle: sys::hipblasHandle_t, + trans_a: sys::hipblasOperation_t, + trans_b: sys::hipblasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const Self, + a: *const Self, + lda: i32, + b: *const Self, + ldb: i32, + beta: *const Self, + c: *mut Self, + ldc: i32, + ) -> sys::hipblasStatus_t { + sys::hipblasZgemm( + handle, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + ) + } +} + +/// Performs matrix-matrix multiplication: C = alpha * op(A) * op(B) + beta * C +/// +/// # Arguments +/// * `handle` - HIPBLAS library handle +/// * `trans_a` - How to transform matrix A +/// * `trans_b` - How to transform matrix B +/// * `m` - Number of rows in op(A) and C +/// * `n` - Number of columns in op(B) and C +/// * `k` - Number of columns in op(A) and rows in op(B) +/// * `alpha` - Scalar multiplier for AB +/// * `a` - Input matrix A +/// * `lda` - Leading dimension of A +/// * `b` - Input matrix B +/// * `ldb` - Leading dimension of B +/// * `beta` - Scalar multiplier for C +/// * `c` - Input/output matrix C +/// * `ldc` - Leading dimension of C +/// +/// # Returns +/// * `Ok(())` if successful +/// * `Err(HipError)` if operation failed +pub fn gemm( + handle: &BlasHandle, + trans_a: Operation, + trans_b: Operation, + m: i32, + n: i32, + k: i32, + alpha: &T, + a: &MemoryPointer, + lda: i32, + b: &MemoryPointer, + ldb: i32, + beta: &T, + c: &mut MemoryPointer, + ldc: i32, +) -> Result<()> { + unsafe { + let code = T::hipblas_gemm( + handle.handle(), + trans_a.into(), + trans_b.into(), + m, + n, + k, + alpha, + a.as_pointer(), + lda, + b.as_pointer(), + ldb, + beta, + c.as_pointer(), + ldc, + ); + ((), code).to_result() + } +} + +#[cfg(test)] +mod tests { + use crate::Complex32; + + use super::*; + + #[test] + fn test_hgemm() { + let handle = BlasHandle::new().unwrap(); + let m = 2; + let n = 2; + let k = 2; + + let a = MemoryPointer::::alloc(m as usize * k as usize).unwrap(); + let b = MemoryPointer::::alloc(k as usize * n as usize).unwrap(); + let mut c = MemoryPointer::::alloc(m as usize * n as usize).unwrap(); + + let alpha = 1.0 as u16; // 1.0 in half precision + let beta = 0.0 as u16; // 0.0 in half precision + + let result = gemm( + &handle, + Operation::None, + Operation::None, + m, + n, + k, + &alpha, + &a, + m, + &b, + k, + &beta, + &mut c, + m, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_sgemm() { + let handle = BlasHandle::new().unwrap(); + let m = 2; + let n = 2; + let k = 2; + + let a = MemoryPointer::::alloc(m as usize * k as usize).unwrap(); + let b = MemoryPointer::::alloc(k as usize * n as usize).unwrap(); + let mut c = MemoryPointer::::alloc(m as usize * n as usize).unwrap(); + + let alpha: f32 = 1.0; + let beta: f32 = 0.0; + + let result = gemm( + &handle, + Operation::None, + Operation::None, + m, + n, + k, + &alpha, + &a, + m, + &b, + k, + &beta, + &mut c, + m, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_dgemm() { + let handle = BlasHandle::new().unwrap(); + let m = 2; + let n = 2; + let k = 2; + + let a = MemoryPointer::::alloc(m as usize * k as usize).unwrap(); + let b = MemoryPointer::::alloc(k as usize * n as usize).unwrap(); + let mut c = MemoryPointer::::alloc(m as usize * n as usize).unwrap(); + + let alpha: f64 = 1.0; + let beta: f64 = 0.0; + + let result = gemm( + &handle, + Operation::None, + Operation::None, + m, + n, + k, + &alpha, + &a, + m, + &b, + k, + &beta, + &mut c, + m, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_cgemm() { + let handle = BlasHandle::new().unwrap(); + let m = 2; + let n = 2; + let k = 2; + + let a = MemoryPointer::::alloc(m as usize * k as usize).unwrap(); + let b = MemoryPointer::::alloc(k as usize * n as usize).unwrap(); + let mut c = MemoryPointer::::alloc(m as usize * n as usize).unwrap(); + + let alpha = Complex32::new(1.0, 0.0); + let beta = Complex32::new(0.0, 0.0); + + let result = gemm( + &handle, + Operation::None, + Operation::None, + m, + n, + k, + &alpha, + &a, + m, + &b, + k, + &beta, + &mut c, + m, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_zgemm() { + let handle = BlasHandle::new().unwrap(); + let m = 2; + let n = 2; + let k = 2; + + let a = MemoryPointer::::alloc(m as usize * k as usize).unwrap(); + let b = MemoryPointer::::alloc(k as usize * n as usize).unwrap(); + let mut c = + MemoryPointer::::alloc(m as usize * n as usize).unwrap(); + + let alpha = sys::hipblasDoubleComplex { x: 1.0, y: 0.0 }; + let beta = sys::hipblasDoubleComplex { x: 0.0, y: 0.0 }; + + let result = gemm( + &handle, + Operation::None, + Operation::None, + m, + n, + k, + &alpha, + &a, + m, + &b, + k, + &beta, + &mut c, + m, + ); + assert!(result.is_ok()); + } + + #[test] + fn test_gemm_error() { + let handle = BlasHandle::new().unwrap(); + let m = -1; // Invalid dimension + let n = 2; + let k = 2; + + let a = MemoryPointer::::alloc(4).unwrap(); + let b = MemoryPointer::::alloc(4).unwrap(); + let mut c = MemoryPointer::::alloc(4).unwrap(); + + let alpha: f32 = 1.0; + let beta: f32 = 0.0; + + let result = gemm( + &handle, + Operation::None, + Operation::None, + m, + n, + k, + &alpha, + &a, + m, + &b, + k, + &beta, + &mut c, + m, + ); + assert!(result.is_err()); + } +} diff --git a/src/hipblas/handle.rs b/src/hipblas/handle.rs index 1d25c17..f67e0a2 100644 --- a/src/hipblas/handle.rs +++ b/src/hipblas/handle.rs @@ -1,5 +1,6 @@ +use super::Result; +use crate::result::ResultExt; use crate::sys; -use crate::{HipResult, Result}; use std::fmt; /// A handle to a hipBLAS library context. @@ -118,13 +119,6 @@ mod tests { ); } - #[test] - fn test_handle_clone_not_implemented() { - let handle = BlasHandle::new().unwrap(); - // This should fail to compile if you try to uncomment it - // let _cloned = handle.clone(); - } - #[test] fn test_handle_send_sync() { // Test that handle can be sent between threads diff --git a/src/hipblas/mod.rs b/src/hipblas/mod.rs index 7a0a0c5..31dc80e 100644 --- a/src/hipblas/mod.rs +++ b/src/hipblas/mod.rs @@ -1,7 +1,9 @@ +mod gemm; mod handle; +mod result; mod types; -use crate::sys; - +pub use gemm::*; pub use handle::*; +pub use result::*; pub use types::*; diff --git a/src/hipblas/result.rs b/src/hipblas/result.rs new file mode 100644 index 0000000..98e4d19 --- /dev/null +++ b/src/hipblas/result.rs @@ -0,0 +1,163 @@ +use crate::result::{ResultExt, StatusCode}; + +#[repr(u32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BlasStatus { + Success = 0, + NotInitialized = 1, + AllocationFailed = 2, + InvalidValue = 3, + MappingError = 4, + ExecutionFailed = 5, + InternalError = 6, + NotSupported = 7, + ArchMismatch = 8, + HandleIsNullPointer = 9, + InvalidEnum = 10, + Unknown = 11, +} + +impl BlasStatus { + fn from(status: u32) -> Self { + match status { + 0 => BlasStatus::Success, + 1 => BlasStatus::NotInitialized, + 2 => BlasStatus::AllocationFailed, + 3 => BlasStatus::InvalidValue, + 4 => BlasStatus::MappingError, + 5 => BlasStatus::ExecutionFailed, + 6 => BlasStatus::InternalError, + 7 => BlasStatus::NotSupported, + 8 => BlasStatus::ArchMismatch, + 9 => BlasStatus::HandleIsNullPointer, + 10 => BlasStatus::InvalidEnum, + _ => BlasStatus::Unknown, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct BlasError { + pub status: BlasStatus, + pub code: u32, +} + +impl BlasError { + pub fn new(code: u32) -> Self { + Self { + status: BlasStatus::from(code), + code, + } + } + + pub fn from_status(status: BlasStatus) -> Self { + Self { + status, + code: status as u32, + } + } +} + +impl StatusCode for BlasError { + fn is_success(&self) -> bool { + self.status == BlasStatus::Success + } + + fn code(&self) -> u32 { + self.code as u32 + } + + fn kind_str(&self) -> &'static str { + "HIPBLAS" + } + + fn status_str(&self) -> &'static str { + match self.status { + BlasStatus::Success => "Success", + BlasStatus::NotInitialized => "NotInitialized", + BlasStatus::AllocationFailed => "AllocationFailed", + BlasStatus::InvalidValue => "InvalidValue", + BlasStatus::MappingError => "MappingError", + BlasStatus::ExecutionFailed => "ExecutionFailed", + BlasStatus::InternalError => "InternalError", + BlasStatus::NotSupported => "NotSupported", + BlasStatus::ArchMismatch => "ArchMismatch", + BlasStatus::HandleIsNullPointer => "HandleIsNullPointer", + BlasStatus::InvalidEnum => "InvalidEnum", + BlasStatus::Unknown => "Unknown", + } + } +} + +pub type Result = std::result::Result; + +impl ResultExt for (T, u32) { + type Value = T; + fn to_result(self) -> Result { + let (value, status) = self; + (value, BlasError::new(status)).to_result() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_blas_status_from() { + assert_eq!(BlasStatus::from(0), BlasStatus::Success); + assert_eq!(BlasStatus::from(1), BlasStatus::NotInitialized); + assert_eq!(BlasStatus::from(2), BlasStatus::AllocationFailed); + assert_eq!(BlasStatus::from(3), BlasStatus::InvalidValue); + assert_eq!(BlasStatus::from(4), BlasStatus::MappingError); + assert_eq!(BlasStatus::from(5), BlasStatus::ExecutionFailed); + assert_eq!(BlasStatus::from(6), BlasStatus::InternalError); + assert_eq!(BlasStatus::from(7), BlasStatus::NotSupported); + assert_eq!(BlasStatus::from(8), BlasStatus::ArchMismatch); + assert_eq!(BlasStatus::from(9), BlasStatus::HandleIsNullPointer); + assert_eq!(BlasStatus::from(10), BlasStatus::InvalidEnum); + assert_eq!(BlasStatus::from(11), BlasStatus::Unknown); + assert_eq!(BlasStatus::from(999), BlasStatus::Unknown); + } + + #[test] + fn test_blas_error_new() { + let error = BlasError::new(3); + assert_eq!(error.status, BlasStatus::InvalidValue); + assert_eq!(error.code, 3); + } + + #[test] + fn test_blas_error_from_status() { + let error = BlasError::from_status(BlasStatus::Success); + assert_eq!(error.status, BlasStatus::Success); + assert_eq!(error.code, 0); + } + + #[test] + fn test_status_code_traits() { + let success = BlasError::new(0); + let error = BlasError::new(1); + + assert!(success.is_success()); + assert!(!error.is_success()); + + assert_eq!(success.code(), 0); + assert_eq!(error.code(), 1); + + assert_eq!(success.kind_str(), "HIPBLAS"); + assert_eq!(error.kind_str(), "HIPBLAS"); + } + + #[test] + fn test_result_ext() { + let success: Result = (42, 0).to_result(); + let error: Result = (42, 1).to_result(); + + assert!(success.is_ok()); + assert!(error.is_err()); + + assert_eq!(success.unwrap(), 42); + assert_eq!(error.unwrap_err().status, BlasStatus::NotInitialized); + } +} diff --git a/src/hipblas/types.rs b/src/hipblas/types.rs index bd7dc07..2d9f0f4 100644 --- a/src/hipblas/types.rs +++ b/src/hipblas/types.rs @@ -3,9 +3,9 @@ use crate::sys; #[repr(u32)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Operation { - None = 0, // HIPBLAS_OP_N - Transpose = 1, // HIPBLAS_OP_T - Conjugate = 2, // HIPBLAS_OP_C + None = 111, // HIPBLAS_OP_N, Operate with the matrix. + Transpose = 112, // HIPBLAS_OP_T + Conjugate = 113, // HIPBLAS_OP_C } impl From for sys::hipblasOperation_t { @@ -18,33 +18,111 @@ impl From for sys::hipblasOperation_t { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Status { Success = 0, - Handle = 1, - NotInitialized = 2, + NotInitialized = 1, + AllocationFailed = 2, InvalidValue = 3, - ArchMismatch = 4, - MappingError = 5, - ExecutionFailed = 6, - InternalError = 7, - NotSupported = 8, - MemoryError = 9, - AllocationFailed = 10, + MappingError = 4, + ExecutionFailed = 5, + InternalError = 6, + NotSupported = 7, + ArchMismatch = 8, + HandleIsNullPointer = 9, + InvalidEnum = 10, + Unknown = 11, // back-end returned an unsupported status code } impl From for Status { fn from(status: sys::hipblasStatus_t) -> Self { match status { 0 => Status::Success, - 1 => Status::Handle, - 2 => Status::NotInitialized, + 1 => Status::NotInitialized, + 2 => Status::AllocationFailed, 3 => Status::InvalidValue, - 4 => Status::ArchMismatch, - 5 => Status::MappingError, - 6 => Status::ExecutionFailed, - 7 => Status::InternalError, - 8 => Status::NotSupported, - 9 => Status::MemoryError, - 10 => Status::AllocationFailed, - _ => Status::InternalError, + 4 => Status::MappingError, + 5 => Status::ExecutionFailed, + 6 => Status::InternalError, + 7 => Status::NotSupported, + 8 => Status::ArchMismatch, + 9 => Status::HandleIsNullPointer, + 10 => Status::InvalidEnum, + _ => Status::Unknown, } } } + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Complex32 { + inner: sys::hipblasComplex, +} + +impl Complex32 { + /// Creates a new complex number from real and imaginary parts + pub fn new(r: f32, i: f32) -> Self { + Self { + inner: sys::hipblasComplex { x: r, y: i }, + } + } + + /// Returns the real part + pub fn real(&self) -> f32 { + self.inner.x + } + + /// Returns the imaginary part + pub fn imag(&self) -> f32 { + self.inner.y + } + + /// Returns the complex conjugate + pub fn conj(&self) -> Self { + Self::new(self.real(), -self.imag()) + } + + /// Returns the magnitude (absolute value) of the complex number + pub fn abs(&self) -> f32 { + (self.real() * self.real() + self.imag() * self.imag()).sqrt() + } + + /// Returns the argument (phase) of the complex number in radians + pub fn arg(&self) -> f32 { + self.imag().atan2(self.real()) + } +} + +impl From for Complex32 { + fn from(c: sys::hipblasComplex) -> Self { + Self { inner: c } + } +} + +impl From for sys::hipblasComplex { + fn from(c: Complex32) -> Self { + c.inner + } +} + +impl Default for Complex32 { + fn default() -> Self { + Self::new(0.0, 0.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_complex_creation() { + let c = Complex32::new(1.0, 2.0); + assert_eq!(c.real(), 1.0); + assert_eq!(c.imag(), 2.0); + } + + #[test] + fn test_complex_conjugate() { + let c = Complex32::new(1.0, 2.0); + let conj = c.conj(); + assert_eq!(conj.real(), 1.0); + assert_eq!(conj.imag(), -2.0); + } +} diff --git a/src/lib.rs b/src/lib.rs index 7a8a26e..7c105ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,9 @@ #![allow(non_upper_case_globals)] mod core; mod hipblas; +mod result; mod sys; pub use core::*; pub use hipblas::*; +pub use result::*; diff --git a/src/result.rs b/src/result.rs new file mode 100644 index 0000000..55e19c9 --- /dev/null +++ b/src/result.rs @@ -0,0 +1,144 @@ +use std::fmt; + +pub trait StatusCode: fmt::Debug { + fn is_success(&self) -> bool; + fn code(&self) -> u32; + fn kind_str(&self) -> &'static str; + fn status_str(&self) -> &'static str; + + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} status: {} (code: {})", + self.kind_str(), + self.status_str(), + self.code() + ) + } +} + +pub trait ResultExt { + type Value; + fn to_result(self) -> std::result::Result; +} + +impl ResultExt for (T, S) { + type Value = T; + + fn to_result(self) -> std::result::Result { + let (value, status) = self; + if status.is_success() { + Ok(value) + } else { + Err(status) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug)] + struct TestStatus { + success: bool, + error_code: u32, + } + + impl StatusCode for TestStatus { + fn is_success(&self) -> bool { + self.success + } + + fn code(&self) -> u32 { + self.error_code + } + + fn kind_str(&self) -> &'static str { + if self.success { + "Success" + } else { + "Failure" + } + } + + fn status_str(&self) -> &'static str { + if self.success { + "none" + } else { + "error" + } + } + } + + // Implement Display for TestStatus + #[test] + fn test_successful_result() { + let value = 42; + let status = TestStatus { + success: true, + error_code: 0, + }; + + let result = (value, status).to_result(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + } + + #[test] + fn test_error_result() { + let value = 42; + let status = TestStatus { + success: false, + error_code: 1, + }; + + let result = (value, status).to_result(); + assert!(result.is_err()); + + let err = result.err().unwrap(); + assert_eq!(err.code(), 1); + assert!(!err.is_success()); + } + + #[test] + fn test_status_kind_str() { + let success_status = TestStatus { + success: true, + error_code: 0, + }; + assert_eq!(success_status.kind_str(), "Success"); + + let error_status = TestStatus { + success: false, + error_code: 1, + }; + assert_eq!(error_status.kind_str(), "Failure"); + } + + #[test] + fn test_status_display_format() { + let status = TestStatus { + success: false, + error_code: 500, + }; + assert_eq!(format!("{}", status), "Failure error: error (code: 500)"); + + let status = TestStatus { + success: true, + error_code: 0, + }; + assert_eq!(format!("{}", status), "Success error: none (code: 0)"); + } + impl fmt::Display for TestStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} error: {} (code: {})", + self.kind_str(), + if self.success { "none" } else { "error" }, + self.code() + ) + } + } +} From 94089e5a6c5a84fe248ed081ac435847f58f70e7 Mon Sep 17 00:00:00 2001 From: Anders Smedegaard Pedersen Date: Tue, 31 Dec 2024 10:15:12 +0100 Subject: [PATCH 11/11] 157 add blas call (#158) * add blas_call * ignore unused import warnings * make sys public --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/core/device.rs | 1 + src/core/device_types.rs | 1 + src/core/hip_call.rs | 9 ++- src/core/init.rs | 1 + src/core/mod.rs | 1 + src/core/stream.rs | 1 + src/hipblas/blas_call.rs | 170 +++++++++++++++++++++++++++++++++++++++ src/hipblas/gemm.rs | 4 +- src/hipblas/handle.rs | 4 +- src/hipblas/mod.rs | 3 + src/hipblas/result.rs | 8 +- src/lib.rs | 3 +- 12 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 src/hipblas/blas_call.rs diff --git a/src/core/device.rs b/src/core/device.rs index 9cabea3..c04db61 100644 --- a/src/core/device.rs +++ b/src/core/device.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipError, HipResult, HipStatus}; use super::{DeviceP2PAttribute, MemPool, PCIBusId}; use crate::result::ResultExt; diff --git a/src/core/device_types.rs b/src/core/device_types.rs index d7bae87..8fb1c62 100644 --- a/src/core/device_types.rs +++ b/src/core/device_types.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipError, HipResult, HipStatus}; use crate::sys; use std::ffi::CStr; diff --git a/src/core/hip_call.rs b/src/core/hip_call.rs index 41ad169..3b19427 100644 --- a/src/core/hip_call.rs +++ b/src/core/hip_call.rs @@ -1,6 +1,9 @@ -use super::result::{HipResult, HipStatus}; -use crate::result::ResultExt; -use crate::sys; +#[allow(unused_imports)] +use { + super::result::{HipResult, HipStatus}, + crate::result::ResultExt, + crate::sys, +}; #[macro_export] macro_rules! hip_call { diff --git a/src/core/init.rs b/src/core/init.rs index 66b2a7e..fac84a1 100644 --- a/src/core/init.rs +++ b/src/core/init.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipResult, HipStatus}; pub use crate::result::ResultExt; use crate::sys; diff --git a/src/core/mod.rs b/src/core/mod.rs index 197ae87..1421ae0 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -12,6 +12,7 @@ mod stream; pub use device::*; pub use device_types::*; pub use flags::*; +#[allow(unused_imports)] pub use hip_call::*; pub use init::*; pub use memory::*; diff --git a/src/core/stream.rs b/src/core/stream.rs index 171b546..b5c2d73 100644 --- a/src/core/stream.rs +++ b/src/core/stream.rs @@ -1,3 +1,4 @@ +#[allow(unused_imports)] use super::result::{HipResult, HipStatus}; use crate::result::ResultExt; use crate::sys; diff --git a/src/hipblas/blas_call.rs b/src/hipblas/blas_call.rs new file mode 100644 index 0000000..b824b33 --- /dev/null +++ b/src/hipblas/blas_call.rs @@ -0,0 +1,170 @@ +#[allow(unused_imports)] +use { + super::result::{BlasError, BlasResult}, + crate::result::ResultExt, + crate::sys, +}; + +/// Executes a BLAS Basic Linear Algebra Subprograms) function call and converts the result into a BlasResult. +/// +/// This macro wraps unsafe BLAS function calls and handles error checking by converting +/// the returned status code into a proper Result value. +/// +/// # Arguments +/// +/// * `$call` - The BLAS function call expression to execute +/// +/// # Returns +/// +/// * `BlasResult<()>` - Ok(()) if successful, Err(BlasError) if there was an error +/// +/// # Examples +/// +/// ```ignore +/// // this example will not compile, but give the basic idea of how +/// // to use `blas_call!` +/// +/// use hip_rs::sys; +/// use hip_rs::blas_call; +/// let mut result = 0.0f32; +/// let blas_result = blas_call!( +/// sys::hipblasSasum(handle.handle(), n, x.as_pointer(), 1, &mut result) +/// ); +/// ``` +/// + +#[macro_export] +macro_rules! blas_call { + ($call:expr) => {{ + let code: u32 = unsafe { $call }; + let result: BlasResult<()> = ((), code).to_result(); + result + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{sys, BlasHandle, MemoryPointer}; + + fn setup_test_vector() -> (BlasHandle, MemoryPointer) { + let handle = BlasHandle::new().unwrap(); + + // Create a test vector with known values + let n = 5; + let vec = MemoryPointer::::alloc(n).unwrap(); + + // Initialize vector with test values + let host_data: Vec = vec![1.0, -2.0, 3.0, -4.0, 5.0]; + unsafe { + sys::hipMemcpy( + vec.as_pointer() as *mut std::ffi::c_void, + host_data.as_ptr() as *const std::ffi::c_void, + (n * std::mem::size_of::()) as usize, + sys::hipMemcpyKind_hipMemcpyHostToDevice, + ); + } + + (handle, vec) + } + + #[test] + fn test_isamin() { + let (handle, vec) = setup_test_vector(); + let mut result: i32 = 0; + + let blas_result = blas_call!(sys::hipblasIsamin( + handle.handle(), + 5, // n elements + vec.as_pointer(), + 1, // stride + &mut result, + )); + assert!(blas_result.is_ok()); + assert_eq!(result, 1); // 1.0 has maximum absolute value (1-based indexing) + } + + #[test] + fn test_isamax() { + let (handle, vec) = setup_test_vector(); + let mut result: i32 = 0; + + let blas_result = blas_call!(sys::hipblasIsamax( + handle.handle(), + 5, // n elements + vec.as_pointer(), + 1, // stride + &mut result, + )); + + assert!(blas_result.is_ok()); + assert_eq!(result, 5); // 5.0 has maximum absolute value (1-based indexing) + } + + #[test] + fn test_sasum() { + let (handle, vec) = setup_test_vector(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + handle.handle(), + 5, // n elements + vec.as_pointer(), + 1, // stride + &mut result, + )); + + assert!(blas_result.is_ok()); + // Expected sum of absolute values: |1.0| + |-2.0| + |3.0| + |-4.0| + |5.0| = 15.0 + assert_eq!(result, 15.0); + } + + #[test] + fn test_invalid_handle() { + let (_, vec) = setup_test_vector(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + std::ptr::null_mut(), // Invalid handle + 5, + vec.as_pointer(), + 1, + &mut result, + )); + + assert!(blas_result.is_err()); + } + + #[test] + fn test_invalid_pointer() { + let handle = BlasHandle::new().unwrap(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + handle.handle(), + 5, + std::ptr::null(), // Invalid pointer + 1, + &mut result, + )); + + assert!(blas_result.is_err()); + } + + #[test] + fn test_zero_length() { + let (handle, vec) = setup_test_vector(); + let mut result: f32 = 0.0; + + let blas_result = blas_call!(sys::hipblasSasum( + handle.handle(), + 0, // Zero length + vec.as_pointer(), + 1, + &mut result, + )); + + assert!(blas_result.is_ok()); + assert_eq!(result, 0.0); // Sum should be 0 + } +} diff --git a/src/hipblas/gemm.rs b/src/hipblas/gemm.rs index 7c251a8..f7dc25b 100644 --- a/src/hipblas/gemm.rs +++ b/src/hipblas/gemm.rs @@ -1,4 +1,4 @@ -use super::{BlasHandle, Operation, Result}; +use super::{BlasHandle, BlasResult, Operation}; use crate::result::ResultExt; use crate::Complex32; use crate::{sys, MemoryPointer}; @@ -184,7 +184,7 @@ pub fn gemm( beta: &T, c: &mut MemoryPointer, ldc: i32, -) -> Result<()> { +) -> BlasResult<()> { unsafe { let code = T::hipblas_gemm( handle.handle(), diff --git a/src/hipblas/handle.rs b/src/hipblas/handle.rs index f67e0a2..a85d070 100644 --- a/src/hipblas/handle.rs +++ b/src/hipblas/handle.rs @@ -1,4 +1,4 @@ -use super::Result; +use super::BlasResult; use crate::result::ResultExt; use crate::sys; use std::fmt; @@ -42,7 +42,7 @@ impl BlasHandle { /// /// let handle = BlasHandle::new().unwrap(); /// ``` - pub fn new() -> Result { + pub fn new() -> BlasResult { let mut handle = std::ptr::null_mut(); unsafe { let status = sys::hipblasCreate(&mut handle); diff --git a/src/hipblas/mod.rs b/src/hipblas/mod.rs index 31dc80e..1e34302 100644 --- a/src/hipblas/mod.rs +++ b/src/hipblas/mod.rs @@ -1,8 +1,11 @@ +mod blas_call; mod gemm; mod handle; mod result; mod types; +#[allow(unused_imports)] +pub use blas_call::*; pub use gemm::*; pub use handle::*; pub use result::*; diff --git a/src/hipblas/result.rs b/src/hipblas/result.rs index 98e4d19..9104158 100644 --- a/src/hipblas/result.rs +++ b/src/hipblas/result.rs @@ -89,11 +89,11 @@ impl StatusCode for BlasError { } } -pub type Result = std::result::Result; +pub type BlasResult = std::result::Result; impl ResultExt for (T, u32) { type Value = T; - fn to_result(self) -> Result { + fn to_result(self) -> BlasResult { let (value, status) = self; (value, BlasError::new(status)).to_result() } @@ -151,8 +151,8 @@ mod tests { #[test] fn test_result_ext() { - let success: Result = (42, 0).to_result(); - let error: Result = (42, 1).to_result(); + let success: BlasResult = (42, 0).to_result(); + let error: BlasResult = (42, 1).to_result(); assert!(success.is_ok()); assert!(error.is_err()); diff --git a/src/lib.rs b/src/lib.rs index 7c105ac..9ebb254 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,9 @@ mod core; mod hipblas; mod result; -mod sys; +pub mod sys; pub use core::*; pub use hipblas::*; pub use result::*; +pub use sys::*;