Skip to content

Commit

Permalink
Merge pull request #1704 from emankov/HIPIFY
Browse files Browse the repository at this point in the history
[HIPIFY][rocBLAS] 64-bit functions support - Step 20
  • Loading branch information
emankov authored Oct 14, 2024
2 parents 1c16c89 + dbe9fd1 commit 9afb61a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 16 deletions.
8 changes: 4 additions & 4 deletions bin/hipify-perl
Original file line number Diff line number Diff line change
Expand Up @@ -1714,6 +1714,7 @@ sub rocSubstitutions {
subst("cublasCtrmv_v2_64", "rocblas_ctrmv_64", "library");
subst("cublasCtrsm", "rocblas_ctrsm", "library");
subst("cublasCtrsmBatched", "rocblas_ctrsm_batched", "library");
subst("cublasCtrsmBatched_64", "rocblas_ctrsm_batched_64", "library");
subst("cublasCtrsm_64", "rocblas_ctrsm_64", "library");
subst("cublasCtrsm_v2", "rocblas_ctrsm", "library");
subst("cublasCtrsm_v2_64", "rocblas_ctrsm_64", "library");
Expand Down Expand Up @@ -1848,6 +1849,7 @@ sub rocSubstitutions {
subst("cublasDtrmv_v2_64", "rocblas_dtrmv_64", "library");
subst("cublasDtrsm", "rocblas_dtrsm", "library");
subst("cublasDtrsmBatched", "rocblas_dtrsm_batched", "library");
subst("cublasDtrsmBatched_64", "rocblas_dtrsm_batched_64", "library");
subst("cublasDtrsm_64", "rocblas_dtrsm_64", "library");
subst("cublasDtrsm_v2", "rocblas_dtrsm", "library");
subst("cublasDtrsm_v2_64", "rocblas_dtrsm_64", "library");
Expand Down Expand Up @@ -2066,6 +2068,7 @@ sub rocSubstitutions {
subst("cublasStrmv_v2_64", "rocblas_strmv_64", "library");
subst("cublasStrsm", "rocblas_strsm", "library");
subst("cublasStrsmBatched", "rocblas_strsm_batched", "library");
subst("cublasStrsmBatched_64", "rocblas_strsm_batched_64", "library");
subst("cublasStrsm_64", "rocblas_strsm_64", "library");
subst("cublasStrsm_v2", "rocblas_strsm", "library");
subst("cublasStrsm_v2_64", "rocblas_strsm_64", "library");
Expand Down Expand Up @@ -2223,6 +2226,7 @@ sub rocSubstitutions {
subst("cublasZtrmv_v2_64", "rocblas_ztrmv_64", "library");
subst("cublasZtrsm", "rocblas_ztrsm", "library");
subst("cublasZtrsmBatched", "rocblas_ztrsm_batched", "library");
subst("cublasZtrsmBatched_64", "rocblas_ztrsm_batched_64", "library");
subst("cublasZtrsm_64", "rocblas_ztrsm_64", "library");
subst("cublasZtrsm_v2", "rocblas_ztrsm", "library");
subst("cublasZtrsm_v2_64", "rocblas_ztrsm_64", "library");
Expand Down Expand Up @@ -12683,7 +12687,6 @@ sub warnRocOnlyUnsupportedFunctions {
my $k = 0;
foreach $func (
"cublasZtrttp",
"cublasZtrsmBatched_64",
"cublasZtrmm_v2_64",
"cublasZtrmm_64",
"cublasZtpttr",
Expand Down Expand Up @@ -12720,7 +12723,6 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasSwapEx_64",
"cublasSwapEx",
"cublasStrttp",
"cublasStrsmBatched_64",
"cublasStrmm_v2_64",
"cublasStrmm_64",
"cublasStpttr",
Expand Down Expand Up @@ -12851,7 +12853,6 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasGemmBatchedEx_64",
"cublasFree",
"cublasDtrttp",
"cublasDtrsmBatched_64",
"cublasDtrmm_v2_64",
"cublasDtrmm_64",
"cublasDtpttr",
Expand All @@ -12877,7 +12878,6 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasDgeam_64",
"cublasDdgmm_64",
"cublasCtrttp",
"cublasCtrsmBatched_64",
"cublasCtrmm_v2_64",
"cublasCtrmm_64",
"cublasCtpttr",
Expand Down
8 changes: 4 additions & 4 deletions docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@
|`cublasCsyrkEx_64`|12.0| | | | | | | | | | | | | | | |
|`cublasCtpttr`| | | | | | | | | | | | | | | | |
|`cublasCtrsmBatched`| | | | |`hipblasCtrsmBatched_v2`|6.0.0| | | | |`rocblas_ctrsm_batched`|3.5.0| | | | |
|`cublasCtrsmBatched_64`|12.0| | | | | | | | | | | | | | | |
|`cublasCtrsmBatched_64`|12.0| | | | | | | | | |`rocblas_ctrsm_batched_64`|6.2.0| | | | |
|`cublasCtrttp`| | | | | | | | | | | | | | | | |
|`cublasDdgmm`| | | | |`hipblasDdgmm`|3.6.0| | | | |`rocblas_ddgmm`|3.5.0| | | | |
|`cublasDdgmm_64`|12.0| | | | | | | | | | | | | | | |
Expand All @@ -1268,7 +1268,7 @@
|`cublasDotcEx_64`|12.0| | | |`hipblasDotcEx_v2_64`|6.2.0| | | | |`rocblas_dotc_ex_64`|6.1.0| | | | |
|`cublasDtpttr`| | | | | | | | | | | | | | | | |
|`cublasDtrsmBatched`| | | | |`hipblasDtrsmBatched`|3.2.0| | | | |`rocblas_dtrsm_batched`|3.5.0| | | | |
|`cublasDtrsmBatched_64`|12.0| | | | | | | | | | | | | | | |
|`cublasDtrsmBatched_64`|12.0| | | | | | | | | |`rocblas_dtrsm_batched_64`|6.2.0| | | | |
|`cublasDtrttp`| | | | | | | | | | | | | | | | |
|`cublasGemmBatchedEx`|9.1| | | |`hipblasGemmBatchedEx_v2`|6.0.0| | | | |`rocblas_gemm_batched_ex`|3.5.0| | | | |
|`cublasGemmBatchedEx_64`|12.0| | | | | | | | | | | | | | | |
Expand Down Expand Up @@ -1302,7 +1302,7 @@
|`cublasSmatinvBatched`| | | | | | | | | | | | | | | | |
|`cublasStpttr`| | | | | | | | | | | | | | | | |
|`cublasStrsmBatched`| | | | |`hipblasStrsmBatched`|3.2.0| | | | |`rocblas_strsm_batched`|3.5.0| | | | |
|`cublasStrsmBatched_64`|12.0| | | | | | | | | | | | | | | |
|`cublasStrsmBatched_64`|12.0| | | | | | | | | |`rocblas_strsm_batched_64`|6.2.0| | | | |
|`cublasStrttp`| | | | | | | | | | | | | | | | |
|`cublasSwapEx`|10.1| | | | | | | | | | | | | | | |
|`cublasSwapEx_64`|12.0| | | | | | | | | | | | | | | |
Expand All @@ -1319,7 +1319,7 @@
|`cublasZmatinvBatched`| | | | | | | | | | | | | | | | |
|`cublasZtpttr`| | | | | | | | | | | | | | | | |
|`cublasZtrsmBatched`| | | | |`hipblasZtrsmBatched_v2`|6.0.0| | | | |`rocblas_ztrsm_batched`|3.5.0| | | | |
|`cublasZtrsmBatched_64`|12.0| | | | | | | | | | | | | | | |
|`cublasZtrsmBatched_64`|12.0| | | | | | | | | |`rocblas_ztrsm_batched_64`|6.2.0| | | | |
|`cublasZtrttp`| | | | | | | | | | | | | | | | |

## **9. BLASLt Function Reference**
Expand Down
8 changes: 4 additions & 4 deletions docs/tables/CUBLAS_API_supported_by_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@
|`cublasCsyrkEx_64`|12.0| | | | | | | | | |
|`cublasCtpttr`| | | | | | | | | | |
|`cublasCtrsmBatched`| | | | |`rocblas_ctrsm_batched`|3.5.0| | | | |
|`cublasCtrsmBatched_64`|12.0| | | | | | | | | |
|`cublasCtrsmBatched_64`|12.0| | | |`rocblas_ctrsm_batched_64`|6.2.0| | | | |
|`cublasCtrttp`| | | | | | | | | | |
|`cublasDdgmm`| | | | |`rocblas_ddgmm`|3.5.0| | | | |
|`cublasDdgmm_64`|12.0| | | | | | | | | |
Expand All @@ -1268,7 +1268,7 @@
|`cublasDotcEx_64`|12.0| | | |`rocblas_dotc_ex_64`|6.1.0| | | | |
|`cublasDtpttr`| | | | | | | | | | |
|`cublasDtrsmBatched`| | | | |`rocblas_dtrsm_batched`|3.5.0| | | | |
|`cublasDtrsmBatched_64`|12.0| | | | | | | | | |
|`cublasDtrsmBatched_64`|12.0| | | |`rocblas_dtrsm_batched_64`|6.2.0| | | | |
|`cublasDtrttp`| | | | | | | | | | |
|`cublasGemmBatchedEx`|9.1| | | |`rocblas_gemm_batched_ex`|3.5.0| | | | |
|`cublasGemmBatchedEx_64`|12.0| | | | | | | | | |
Expand Down Expand Up @@ -1302,7 +1302,7 @@
|`cublasSmatinvBatched`| | | | | | | | | | |
|`cublasStpttr`| | | | | | | | | | |
|`cublasStrsmBatched`| | | | |`rocblas_strsm_batched`|3.5.0| | | | |
|`cublasStrsmBatched_64`|12.0| | | | | | | | | |
|`cublasStrsmBatched_64`|12.0| | | |`rocblas_strsm_batched_64`|6.2.0| | | | |
|`cublasStrttp`| | | | | | | | | | |
|`cublasSwapEx`|10.1| | | | | | | | | |
|`cublasSwapEx_64`|12.0| | | | | | | | | |
Expand All @@ -1319,7 +1319,7 @@
|`cublasZmatinvBatched`| | | | | | | | | | |
|`cublasZtpttr`| | | | | | | | | | |
|`cublasZtrsmBatched`| | | | |`rocblas_ztrsm_batched`|3.5.0| | | | |
|`cublasZtrsmBatched_64`|12.0| | | | | | | | | |
|`cublasZtrsmBatched_64`|12.0| | | |`rocblas_ztrsm_batched_64`|6.2.0| | | | |
|`cublasZtrttp`| | | | | | | | | | |

## **9. BLASLt Function Reference**
Expand Down
12 changes: 8 additions & 4 deletions src/CUDA2HIP_BLAS_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,13 +590,13 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP {

// TRSM - Batched Triangular Solver
{"cublasStrsmBatched", {"hipblasStrsmBatched", "rocblas_strsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}},
{"cublasStrsmBatched_64", {"hipblasStrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}},
{"cublasStrsmBatched_64", {"hipblasStrsmBatched_64", "rocblas_strsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}},
{"cublasDtrsmBatched", {"hipblasDtrsmBatched", "rocblas_dtrsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}},
{"cublasDtrsmBatched_64", {"hipblasDtrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}},
{"cublasDtrsmBatched_64", {"hipblasDtrsmBatched_64", "rocblas_dtrsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}},
{"cublasCtrsmBatched", {"hipblasCtrsmBatched_v2", "rocblas_ctrsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}},
{"cublasCtrsmBatched_64", {"hipblasCtrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}},
{"cublasCtrsmBatched_64", {"hipblasCtrsmBatched_64", "rocblas_ctrsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}},
{"cublasZtrsmBatched", {"hipblasZtrsmBatched_v2", "rocblas_ztrsm_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT}},
{"cublasZtrsmBatched_64", {"hipblasZtrsmBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}},
{"cublasZtrsmBatched_64", {"hipblasZtrsmBatched_64", "rocblas_ztrsm_batched_64", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, HIP_UNSUPPORTED}},

// MATINV - Batched
{"cublasSmatinvBatched", {"hipblasSmatinvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_EXT, UNSUPPORTED}},
Expand Down Expand Up @@ -2407,6 +2407,10 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
{"rocblas_dtrsm_64", {HIP_6020, HIP_0, HIP_0 }},
{"rocblas_ctrsm_64", {HIP_6020, HIP_0, HIP_0 }},
{"rocblas_ztrsm_64", {HIP_6020, HIP_0, HIP_0 }},
{"rocblas_strsm_batched_64", {HIP_6020, HIP_0, HIP_0 }},
{"rocblas_dtrsm_batched_64", {HIP_6020, HIP_0, HIP_0 }},
{"rocblas_ctrsm_batched_64", {HIP_6020, HIP_0, HIP_0 }},
{"rocblas_ztrsm_batched_64", {HIP_6020, HIP_0, HIP_0 }},
};

const std::map<llvm::StringRef, hipAPIChangedVersions> HIP_BLAS_FUNCTION_CHANGED_VER_MAP {
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3008,6 +3008,26 @@ int main() {
// CHECK-NEXT: blasStatus = rocblas_ztrsm_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexB, ldb_64);
blasStatus = cublasZtrsm_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexB, ldb_64);
blasStatus = cublasZtrsm_v2_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, &dcomplexA, lda_64, &dcomplexB, ldb_64);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const float* alpha, const float* const A[], int64_t lda, float* const B[], int64_t ldb, int64_t batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const float* alpha, const float* const A[], int64_t lda, float* const B[], int64_t ldb, int64_t batch_count);
// CHECK: blasStatus = rocblas_strsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &fa, fAarray_const, lda_64, fBarray, ldb_64, batchCount_64);
blasStatus = cublasStrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &fa, fAarray_const, lda_64, fBarray, ldb_64, batchCount_64);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const double* alpha, const double* const A[], int64_t lda, double* const B[], int64_t ldb, int64_t batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_dtrsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const double* alpha, const double* const A[], int64_t lda, double* const B[], int64_t ldb, int64_t batch_count);
// CHECK: blasStatus = rocblas_dtrsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &da, dAarray_const, lda_64, dBarray, ldb_64, batchCount_64);
blasStatus = cublasDtrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &da, dAarray_const, lda_64, dBarray, ldb_64, batchCount_64);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const cuComplex* alpha, const cuComplex* const A[], int64_t lda, cuComplex* const B[], int64_t ldb, int64_t batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_ctrsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], int64_t lda, rocblas_float_complex* const B[], int64_t ldb, int64_t batch_count);
// CHECK: blasStatus = rocblas_ctrsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &complexa, complexAarray_const, lda_64, complexBarray, ldb_64, batchCount_64);
blasStatus = cublasCtrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &complexa, complexAarray_const, lda_64, complexBarray, ldb_64, batchCount_64);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsmBatched_64(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int64_t m, int64_t n, const cuDoubleComplex* alpha, const cuDoubleComplex* const A[], int64_t lda, cuDoubleComplex* const B[], int64_t ldb, int64_t batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched_64(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, int64_t m, int64_t n, const rocblas_double_complex* alpha, const rocblas_double_complex* const A[], int64_t lda, rocblas_double_complex* const B[], int64_t ldb, int64_t batch_count);
// CHECK: blasStatus = rocblas_ztrsm_batched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, dcomplexAarray_const, lda_64, dcomplexBarray, ldb_64, batchCount_64);
blasStatus = cublasZtrsmBatched_64(blasHandle, blasSideMode, blasFillMode, blasOperation, blasDiagType, m_64, n_64, &dcomplexa, dcomplexAarray_const, lda_64, dcomplexBarray, ldb_64, batchCount_64);
#endif

return 0;
Expand Down

0 comments on commit 9afb61a

Please sign in to comment.