From dbe9fd1470d6d5776b9d39d4b688a8a8f73a53b3 Mon Sep 17 00:00:00 2001 From: Evgeny Mankov Date: Mon, 14 Oct 2024 21:42:45 +0100 Subject: [PATCH] [HIPIFY][rocBLAS] 64-bit functions support - Step 20 + `rocblas_(s|d|c|z)trsm_batched_64` support + Updated synthetic tests, the regenerated `hipify-perl`, and `BLAS` `CUDA2HIP` documentation --- bin/hipify-perl | 8 ++++---- .../CUBLAS_API_supported_by_HIP_and_ROC.md | 8 ++++---- docs/tables/CUBLAS_API_supported_by_ROC.md | 8 ++++---- src/CUDA2HIP_BLAS_API_functions.cpp | 12 +++++++---- .../synthetic/libraries/cublas2rocblas_v2.cu | 20 +++++++++++++++++++ 5 files changed, 40 insertions(+), 16 deletions(-) diff --git a/bin/hipify-perl b/bin/hipify-perl index c97c5533..0075ea16 100755 --- a/bin/hipify-perl +++ b/bin/hipify-perl @@ -1714,6 +1714,7 @@ sub rocSubstitutions { subst("cublasCtrmv_v2_64", "rocblas_ctrmv_64", "library"); subst("cublasCtrsm", "rocblas_ctrsm", "library"); subst("cublasCtrsmBatched", "rocblas_ctrsm_batched", "library"); + subst("cublasCtrsmBatched_64", "rocblas_ctrsm_batched_64", "library"); subst("cublasCtrsm_64", "rocblas_ctrsm_64", "library"); subst("cublasCtrsm_v2", "rocblas_ctrsm", "library"); subst("cublasCtrsm_v2_64", "rocblas_ctrsm_64", "library"); @@ -1848,6 +1849,7 @@ sub rocSubstitutions { subst("cublasDtrmv_v2_64", "rocblas_dtrmv_64", "library"); subst("cublasDtrsm", "rocblas_dtrsm", "library"); subst("cublasDtrsmBatched", "rocblas_dtrsm_batched", "library"); + subst("cublasDtrsmBatched_64", "rocblas_dtrsm_batched_64", "library"); subst("cublasDtrsm_64", "rocblas_dtrsm_64", "library"); subst("cublasDtrsm_v2", "rocblas_dtrsm", "library"); subst("cublasDtrsm_v2_64", "rocblas_dtrsm_64", "library"); @@ -2066,6 +2068,7 @@ sub rocSubstitutions { subst("cublasStrmv_v2_64", "rocblas_strmv_64", "library"); subst("cublasStrsm", "rocblas_strsm", "library"); subst("cublasStrsmBatched", "rocblas_strsm_batched", "library"); + subst("cublasStrsmBatched_64", "rocblas_strsm_batched_64", "library"); subst("cublasStrsm_64", "rocblas_strsm_64", "library"); subst("cublasStrsm_v2", "rocblas_strsm", "library"); subst("cublasStrsm_v2_64", "rocblas_strsm_64", "library"); @@ -2223,6 +2226,7 @@ sub rocSubstitutions { subst("cublasZtrmv_v2_64", "rocblas_ztrmv_64", "library"); subst("cublasZtrsm", "rocblas_ztrsm", "library"); subst("cublasZtrsmBatched", "rocblas_ztrsm_batched", "library"); + subst("cublasZtrsmBatched_64", "rocblas_ztrsm_batched_64", "library"); subst("cublasZtrsm_64", "rocblas_ztrsm_64", "library"); subst("cublasZtrsm_v2", "rocblas_ztrsm", "library"); subst("cublasZtrsm_v2_64", "rocblas_ztrsm_64", "library"); @@ -12683,7 +12687,6 @@ sub warnRocOnlyUnsupportedFunctions { my $k = 0; foreach $func ( "cublasZtrttp", - "cublasZtrsmBatched_64", "cublasZtrmm_v2_64", "cublasZtrmm_64", "cublasZtpttr", @@ -12720,7 +12723,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasSwapEx_64", "cublasSwapEx", "cublasStrttp", - "cublasStrsmBatched_64", "cublasStrmm_v2_64", "cublasStrmm_64", "cublasStpttr", @@ -12851,7 +12853,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasGemmBatchedEx_64", "cublasFree", "cublasDtrttp", - "cublasDtrsmBatched_64", "cublasDtrmm_v2_64", "cublasDtrmm_64", "cublasDtpttr", @@ -12877,7 +12878,6 @@ sub warnRocOnlyUnsupportedFunctions { "cublasDgeam_64", "cublasDdgmm_64", "cublasCtrttp", - "cublasCtrsmBatched_64", "cublasCtrmm_v2_64", "cublasCtrmm_64", "cublasCtpttr", diff --git a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md index adc47104..5fdf8baf 100644 --- a/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md @@ -1250,7 +1250,7 @@ |`cublasCsyrkEx_64`|12.0| | | | | | | | | | | | | | | | |`cublasCtpttr`| | | | | | | | | | | | | | | | | |`cublasCtrsmBatched`| | | | |`hipblasCtrsmBatched_v2`|6.0.0| | | | |`rocblas_ctrsm_batched`|3.5.0| | | | | -|`cublasCtrsmBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasCtrsmBatched_64`|12.0| | | | | | | | | |`rocblas_ctrsm_batched_64`|6.2.0| | | | | |`cublasCtrttp`| | | | | | | | | | | | | | | | | |`cublasDdgmm`| | | | |`hipblasDdgmm`|3.6.0| | | | |`rocblas_ddgmm`|3.5.0| | | | | |`cublasDdgmm_64`|12.0| | | | | | | | | | | | | | | | @@ -1268,7 +1268,7 @@ |`cublasDotcEx_64`|12.0| | | |`hipblasDotcEx_v2_64`|6.2.0| | | | |`rocblas_dotc_ex_64`|6.1.0| | | | | |`cublasDtpttr`| | | | | | | | | | | | | | | | | |`cublasDtrsmBatched`| | | | |`hipblasDtrsmBatched`|3.2.0| | | | |`rocblas_dtrsm_batched`|3.5.0| | | | | -|`cublasDtrsmBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasDtrsmBatched_64`|12.0| | | | | | | | | |`rocblas_dtrsm_batched_64`|6.2.0| | | | | |`cublasDtrttp`| | | | | | | | | | | | | | | | | |`cublasGemmBatchedEx`|9.1| | | |`hipblasGemmBatchedEx_v2`|6.0.0| | | | |`rocblas_gemm_batched_ex`|3.5.0| | | | | |`cublasGemmBatchedEx_64`|12.0| | | | | | | | | | | | | | | | @@ -1302,7 +1302,7 @@ |`cublasSmatinvBatched`| | | | | | | | | | | | | | | | | |`cublasStpttr`| | | | | | | | | | | | | | | | | |`cublasStrsmBatched`| | | | |`hipblasStrsmBatched`|3.2.0| | | | |`rocblas_strsm_batched`|3.5.0| | | | | -|`cublasStrsmBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasStrsmBatched_64`|12.0| | | | | | | | | |`rocblas_strsm_batched_64`|6.2.0| | | | | |`cublasStrttp`| | | | | | | | | | | | | | | | | |`cublasSwapEx`|10.1| | | | | | | | | | | | | | | | |`cublasSwapEx_64`|12.0| | | | | | | | | | | | | | | | @@ -1319,7 +1319,7 @@ |`cublasZmatinvBatched`| | | | | | | | | | | | | | | | | |`cublasZtpttr`| | | | | | | | | | | | | | | | | |`cublasZtrsmBatched`| | | | |`hipblasZtrsmBatched_v2`|6.0.0| | | | |`rocblas_ztrsm_batched`|3.5.0| | | | | -|`cublasZtrsmBatched_64`|12.0| | | | | | | | | | | | | | | | +|`cublasZtrsmBatched_64`|12.0| | | | | | | | | |`rocblas_ztrsm_batched_64`|6.2.0| | | | | |`cublasZtrttp`| | | | | | | | | | | | | | | | | ## **9. BLASLt Function Reference** diff --git a/docs/tables/CUBLAS_API_supported_by_ROC.md b/docs/tables/CUBLAS_API_supported_by_ROC.md index 1ad86206..3b98b6bc 100644 --- a/docs/tables/CUBLAS_API_supported_by_ROC.md +++ b/docs/tables/CUBLAS_API_supported_by_ROC.md @@ -1250,7 +1250,7 @@ |`cublasCsyrkEx_64`|12.0| | | | | | | | | | |`cublasCtpttr`| | | | | | | | | | | |`cublasCtrsmBatched`| | | | |`rocblas_ctrsm_batched`|3.5.0| | | | | -|`cublasCtrsmBatched_64`|12.0| | | | | | | | | | +|`cublasCtrsmBatched_64`|12.0| | | |`rocblas_ctrsm_batched_64`|6.2.0| | | | | |`cublasCtrttp`| | | | | | | | | | | |`cublasDdgmm`| | | | |`rocblas_ddgmm`|3.5.0| | | | | |`cublasDdgmm_64`|12.0| | | | | | | | | | @@ -1268,7 +1268,7 @@ |`cublasDotcEx_64`|12.0| | | |`rocblas_dotc_ex_64`|6.1.0| | | | | |`cublasDtpttr`| | | | | | | | | | | |`cublasDtrsmBatched`| | | | |`rocblas_dtrsm_batched`|3.5.0| | | | | -|`cublasDtrsmBatched_64`|12.0| | | | | | | | | | +|`cublasDtrsmBatched_64`|12.0| | | |`rocblas_dtrsm_batched_64`|6.2.0| | | | | |`cublasDtrttp`| | | | | | | | | | | |`cublasGemmBatchedEx`|9.1| | | |`rocblas_gemm_batched_ex`|3.5.0| | | | | |`cublasGemmBatchedEx_64`|12.0| | | | | | | | | | @@ -1302,7 +1302,7 @@ |`cublasSmatinvBatched`| | | | | | | | | | | |`cublasStpttr`| | | | | | | | | | | |`cublasStrsmBatched`| | | | |`rocblas_strsm_batched`|3.5.0| | | | | -|`cublasStrsmBatched_64`|12.0| | | | | | | | | | +|`cublasStrsmBatched_64`|12.0| | | |`rocblas_strsm_batched_64`|6.2.0| | | | | |`cublasStrttp`| | | | | | | | | | | |`cublasSwapEx`|10.1| | | | | | | | | | |`cublasSwapEx_64`|12.0| | | | | | | | | | @@ -1319,7 +1319,7 @@ |`cublasZmatinvBatched`| | | | | | | | | | | |`cublasZtpttr`| | | | | | | | | | | |`cublasZtrsmBatched`| | | | |`rocblas_ztrsm_batched`|3.5.0| | | | | -|`cublasZtrsmBatched_64`|12.0| | | | | | | | | | +|`cublasZtrsmBatched_64`|12.0| | | |`rocblas_ztrsm_batched_64`|6.2.0| | | | | |`cublasZtrttp`| | | | | | | | | | | ## **9. BLASLt Function Reference** diff --git a/src/CUDA2HIP_BLAS_API_functions.cpp b/src/CUDA2HIP_BLAS_API_functions.cpp index 1dee8c87..f1a096ba 100644 --- a/src/CUDA2HIP_BLAS_API_functions.cpp +++ b/src/CUDA2HIP_BLAS_API_functions.cpp @@ -590,13 +590,13 @@ const std::map CUDA_BLAS_FUNCTION_MAP { // TRSM - Batched Triangular Solver {"cublasStrsmBatched", {"hipblasStrsmBatched", "rocblas_strsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}}, - {"cublasStrsmBatched_64", {"hipblasStrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}}, + {"cublasStrsmBatched_64", {"hipblasStrsmBatched_64", "rocblas_strsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}}, {"cublasDtrsmBatched", {"hipblasDtrsmBatched", "rocblas_dtrsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}}, - {"cublasDtrsmBatched_64", {"hipblasDtrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}}, + {"cublasDtrsmBatched_64", {"hipblasDtrsmBatched_64", "rocblas_dtrsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}}, {"cublasCtrsmBatched", {"hipblasCtrsmBatched_v2", "rocblas_ctrsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}}, - {"cublasCtrsmBatched_64", {"hipblasCtrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}}, + {"cublasCtrsmBatched_64", {"hipblasCtrsmBatched_64", "rocblas_ctrsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}}, {"cublasZtrsmBatched", {"hipblasZtrsmBatched_v2", "rocblas_ztrsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}}, - {"cublasZtrsmBatched_64", {"hipblasZtrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}}, + {"cublasZtrsmBatched_64", {"hipblasZtrsmBatched_64", "rocblas_ztrsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}}, // MATINV - Batched {"cublasSmatinvBatched", {"hipblasSmatinvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}}, @@ -2407,6 +2407,10 @@ const std::map HIP_BLAS_FUNCTION_VER_MAP { {"rocblas_dtrsm_64", {HIP_6020, HIP_0, HIP_0 }}, {"rocblas_ctrsm_64", {HIP_6020, HIP_0, HIP_0 }}, {"rocblas_ztrsm_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_strsm_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_dtrsm_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_ctrsm_batched_64", {HIP_6020, HIP_0, HIP_0 }}, + {"rocblas_ztrsm_batched_64", {HIP_6020, HIP_0, HIP_0 }}, }; const std::map HIP_BLAS_FUNCTION_CHANGED_VER_MAP { diff --git a/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu b/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu index 3ef155d6..0b975ea5 100644 --- a/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu +++ b/tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu @@ -3008,6 +3008,26 @@ int main() { // CHECK-NEXT: blasStatus = rocblas_ztrsm_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexB, ldb_64); blasStatus = cublasZtrsm_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexB, ldb_64); blasStatus = cublasZtrsm_v2_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexB, ldb_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const float* alpha, const float* const A[], int64_t lda, float* const B[], int64_t ldb, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const float* alpha, const float* const A[], int64_t lda, float* const B[], int64_t ldb, int64_t batch_count); + // CHECK: blasStatus = rocblas_strsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &fa, fAarray_const, lda_64, fBarray, ldb_64, batchCount_64); + blasStatus = cublasStrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &fa, fAarray_const, lda_64, fBarray, ldb_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const double* alpha, const double* const A[], int64_t lda, double* const B[], int64_t ldb, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_dtrsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const double* alpha, const double* const A[], int64_t lda, double* const B[], int64_t ldb, int64_t batch_count); + // CHECK: blasStatus = rocblas_dtrsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &da, dAarray_const, lda_64, dBarray, ldb_64, batchCount_64); + blasStatus = cublasDtrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &da, dAarray_const, lda_64, dBarray, ldb_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const cuComplex* alpha, const cuComplex* const A[], int64_t lda, cuComplex* const B[], int64_t ldb, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_ctrsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], int64_t lda, rocblas_float_complex* const B[], int64_t ldb, int64_t batch_count); + // CHECK: blasStatus = rocblas_ctrsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &complexa, complexAarray_const, lda_64, complexBarray, ldb_64, batchCount_64); + blasStatus = cublasCtrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &complexa, complexAarray_const, lda_64, complexBarray, ldb_64, batchCount_64); + + // CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const cuDoubleComplex* alpha, const cuDoubleComplex* const A[], int64_t lda, cuDoubleComplex* const B[], int64_t ldb, int64_t batchCount); + // ROC: ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const rocblas_double_complex* alpha, const rocblas_double_complex* const A[], int64_t lda, rocblas_double_complex* const B[], int64_t ldb, int64_t batch_count); + // CHECK: blasStatus = rocblas_ztrsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, dcomplexAarray_const, lda_64, dcomplexBarray, ldb_64, batchCount_64); + blasStatus = cublasZtrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, dcomplexAarray_const, lda_64, dcomplexBarray, ldb_64, batchCount_64); #endif return 0;