Skip to content

Commit

Permalink
Merge pull request #942 from emankov/HIPIFY
Browse files Browse the repository at this point in the history
[HIPIFY][rocBLAS][tests][fix] Use `const_cast` for some function arguments, which signatures changed in 9.2
  • Loading branch information
emankov authored Jul 9, 2023
2 parents 1147546 + b3ef021 commit 7e693f0
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 36 deletions.
50 changes: 32 additions & 18 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ int main() {
float fresult = 0;

float** fAarray = 0;
const float** const fAarray_const = const_cast<const float**>(fAarray);
float** fBarray = 0;
const float** const fBarray_const = const_cast<const float**>(fBarray);
float** fCarray = 0;
float** fTauarray = 0;

Expand All @@ -255,12 +257,16 @@ int main() {
double dresult = 0;

double** dAarray = 0;
const double** const dAarray_const = const_cast<const double**>(dAarray);
double** dBarray = 0;
const double** const dBarray_const = const_cast<const double**>(dBarray);
double** dCarray = 0;
double** dTauarray = 0;

void** voidAarray = nullptr;
const void** const voidAarray_const = const_cast<const void**>(voidAarray);
void** voidBarray = nullptr;
const void** const voidBarray_const = const_cast<const void**>(voidBarray);
void** voidCarray = nullptr;

// TODO: #1281
Expand All @@ -283,20 +289,28 @@ int main() {
cuDoubleComplex dcomplex, dcomplexa, dcomplexA, dcomplexB, dcomplexC, dcomplexx, dcomplexy, dcomplexs, dcomplexb;

// CHECK: rocblas_float_complex** complexAarray = 0;
// CHECK: const rocblas_float_complex** const complexAarray_const = const_cast<const rocblas_float_complex**>(complexAarray);
// CHECK-NEXT: rocblas_float_complex** complexBarray = 0;
// CHECK: const rocblas_float_complex** const complexBarray_const = const_cast<const rocblas_float_complex**>(complexBarray);
// CHECK-NEXT: rocblas_float_complex** complexCarray = 0;
// CHECK-NEXT: rocblas_float_complex** complexTauarray = 0;
cuComplex** complexAarray = 0;
const cuComplex** const complexAarray_const = const_cast<const cuComplex**>(complexAarray);
cuComplex** complexBarray = 0;
const cuComplex** const complexBarray_const = const_cast<const cuComplex**>(complexBarray);
cuComplex** complexCarray = 0;
cuComplex** complexTauarray = 0;

// CHECK: rocblas_double_complex** dcomplexAarray = 0;
// CHECK: const rocblas_double_complex** const dcomplexAarray_const = const_cast<const rocblas_double_complex**>(dcomplexAarray);
// CHECK-NEXT: rocblas_double_complex** dcomplexBarray = 0;
// CHECK: const rocblas_double_complex** const dcomplexBarray_const = const_cast<const rocblas_double_complex**>(dcomplexBarray);
// CHECK-NEXT: rocblas_double_complex** dcomplexCarray = 0;
// CHECK-NEXT: rocblas_double_complex** dcomplexTauarray = 0;
cuDoubleComplex** dcomplexAarray = 0;
const cuDoubleComplex** const dcomplexAarray_const = const_cast<const cuDoubleComplex**>(dcomplexAarray);
cuDoubleComplex** dcomplexBarray = 0;
const cuDoubleComplex** const dcomplexBarray_const = const_cast<const cuDoubleComplex**>(dcomplexBarray);
cuDoubleComplex** dcomplexCarray = 0;
cuDoubleComplex** dcomplexTauarray = 0;

Expand Down Expand Up @@ -1194,14 +1208,14 @@ int main() {
// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* const Aarray[], int lda, const float* const Barray[], int ldb, const float* beta, float* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_sgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const float* alpha, const float* const A[], rocblas_int lda, const float* const B[], rocblas_int ldb, const float* beta, float* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_sgemm_batched(blasHandle, transa, transb, m, n, k, &fa, fAarray, lda, fBarray, ldb, &fb, fCarray, ldc, batchCount);
blasStatus = cublasSgemmBatched(blasHandle, transa, transb, m, n, k, &fa, fAarray, lda, fBarray, ldb, &fb, fCarray, ldc, batchCount);
// CHECK: blasStatus = rocblas_sgemm_batched(blasHandle, transa, transb, m, n, k, &fa, fAarray_const, lda, fBarray_const, ldb, &fb, fCarray, ldc, batchCount);
blasStatus = cublasSgemmBatched(blasHandle, transa, transb, m, n, k, &fa, fAarray_const, lda, fBarray_const, ldb, &fb, fCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double* alpha, const double* const Aarray[], int lda, const double* const Barray[], int ldb, const double* beta, double* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_dgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const double* alpha, const double* const A[], rocblas_int lda, const double* const B[], rocblas_int ldb, const double* beta, double* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_dgemm_batched(blasHandle, transa, transb, m, n, k, &da, dAarray, lda, dBarray, ldb, &db, dCarray, ldc, batchCount);
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray, lda, dBarray, ldb, &db, dCarray, ldc, batchCount);
// CHECK: blasStatus = rocblas_dgemm_batched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// TODO: #1281
// TODO: __half -> rocblas_half
Expand All @@ -1211,14 +1225,14 @@ int main() {
// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_cgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], rocblas_int lda, const rocblas_float_complex* const B[], rocblas_int ldb, const rocblas_float_complex* beta, rocblas_float_complex* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_cgemm_batched(blasHandle, transa, transb, m, n, k, &complexa, complexAarray, lda, complexBarray, ldb, &complexb, complexCarray, ldc, batchCount);
blasStatus = cublasCgemmBatched(blasHandle, transa, transb, m, n, k, &complexa, complexAarray, lda, complexBarray, ldb, &complexb, complexCarray, ldc, batchCount);
// CHECK: blasStatus = rocblas_cgemm_batched(blasHandle, transa, transb, m, n, k, &complexa, complexAarray_const, lda, complexBarray_const, ldb, &complexb, complexCarray, ldc, batchCount);
blasStatus = cublasCgemmBatched(blasHandle, transa, transb, m, n, k, &complexa, complexAarray_const, lda, complexBarray_const, ldb, &complexb, complexCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuDoubleComplex* alpha, const cuDoubleComplex* const Aarray[], int lda, const cuDoubleComplex* const Barray[], int ldb, const cuDoubleComplex* beta, cuDoubleComplex* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_zgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_double_complex* alpha, const rocblas_double_complex* const A[], rocblas_int lda, const rocblas_double_complex* const B[], rocblas_int ldb, const rocblas_double_complex* beta, rocblas_double_complex* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_zgemm_batched(blasHandle, transa, transb, m, n, k, &dcomplexa, dcomplexAarray, lda, dcomplexBarray, ldb, &dcomplexb, dcomplexCarray, ldc, batchCount);
blasStatus = cublasZgemmBatched(blasHandle, transa, transb, m, n, k, &dcomplexa, dcomplexAarray, lda, dcomplexBarray, ldb, &dcomplexb, dcomplexCarray, ldc, batchCount);
// CHECK: blasStatus = rocblas_zgemm_batched(blasHandle, transa, transb, m, n, k, &dcomplexa, dcomplexAarray_const, lda, dcomplexBarray_const, ldb, &dcomplexb, dcomplexCarray, ldc, batchCount);
blasStatus = cublasZgemmBatched(blasHandle, transa, transb, m, n, k, &dcomplexa, dcomplexAarray_const, lda, dcomplexBarray_const, ldb, &dcomplexb, dcomplexCarray, ldc, batchCount);

// TODO: #1281
// NOTE: void CUBLASWINAPI cublasSsyrk(char uplo, char trans, int n, int k, float alpha, const float* A, int lda, float beta, float* C, int ldc); is not supported by HIP
Expand Down Expand Up @@ -1465,26 +1479,26 @@ int main() {
// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasStrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const float* alpha, const float* const A[], int lda, float* const B[], int ldb, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, rocblas_int n, const float* alpha, const float* const A[], rocblas_int lda, float* const B[], rocblas_int ldb, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_strsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &fa, fAarray, lda, fBarray, ldb, batchCount);
blasStatus = cublasStrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &fa, fAarray, lda, fBarray, ldb, batchCount);
// CHECK: blasStatus = rocblas_strsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &fa, fAarray_const, lda, fBarray, ldb, batchCount);
blasStatus = cublasStrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &fa, fAarray_const, lda, fBarray, ldb, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const double* alpha, const double* const A[], int lda, double* const B[], int ldb, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_dtrsm_batched(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, rocblas_int n, const double* alpha, const double* const A[], rocblas_int lda, double* const B[], rocblas_int ldb, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_dtrsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &da, dAarray, lda, dBarray, ldb, batchCount);
blasStatus = cublasDtrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &da, dAarray, lda, dBarray, ldb, batchCount);
// CHECK: blasStatus = rocblas_dtrsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &da, dAarray_const, lda, dBarray, ldb, batchCount);
blasStatus = cublasDtrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &da, dAarray_const, lda, dBarray, ldb, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const cuComplex* alpha, const cuComplex* const A[], int lda, cuComplex* const B[], int ldb, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_ctrsm_batched(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, rocblas_int n, const rocblas_float_complex* alpha, const rocblas_float_complex* const A[], rocblas_int lda, rocblas_float_complex* const B[], rocblas_int ldb, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_ctrsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &complexa, complexAarray, lda, complexBarray, ldb, batchCount);
blasStatus = cublasCtrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &complexa, complexAarray, lda, complexBarray, ldb, batchCount);
// CHECK: blasStatus = rocblas_ctrsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &complexa, complexAarray_const, lda, complexBarray, ldb, batchCount);
blasStatus = cublasCtrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &complexa, complexAarray_const, lda, complexBarray, ldb, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasZtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const cuDoubleComplex* alpha, const cuDoubleComplex* const A[], int lda, cuDoubleComplex* const B[], int ldb, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched(rocblas_handle handle, rocblas_side side, rocblas_fill uplo, rocblas_operation transA, rocblas_diagonal diag, rocblas_int m, rocblas_int n, const rocblas_double_complex* alpha, const rocblas_double_complex* const A[], rocblas_int lda, rocblas_double_complex* const B[], rocblas_int ldb, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_ztrsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &dcomplexa, dcomplexAarray, lda, dcomplexBarray, ldb, batchCount);
blasStatus = cublasZtrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &dcomplexa, dcomplexAarray, lda, dcomplexBarray, ldb, batchCount);
// CHECK: blasStatus = rocblas_ztrsm_batched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &dcomplexa, dcomplexAarray_const, lda, dcomplexBarray, ldb, batchCount);
blasStatus = cublasZtrsmBatched(blasHandle, blasSideMode, blasFillMode, transa, blasDiagType, m, n, &dcomplexa, dcomplexAarray_const, lda, dcomplexBarray, ldb, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSdgmm(cublasHandle_t handle, cublasSideMode_t mode, int m, int n, const float* A, int lda, const float* x, int incx, float* C, int ldc);
Expand Down Expand Up @@ -1657,8 +1671,8 @@ int main() {
// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void* alpha, const void* const Aarray[], cudaDataType Atype, int lda, const void* const Barray[], cudaDataType Btype, int ldb, const void* beta, void* const Carray[], cudaDataType Ctype, int ldc, int batchCount, cublasComputeType_t computeType, cublasGemmAlgo_t algo);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_gemm_batched_ex(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha, const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b, rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c, rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type, rocblas_int ldd, rocblas_int batch_count, rocblas_datatype compute_type, rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags);
// CHECK: blasStatus = rocblas_gemm_batched_ex(blasHandle, transa, transb, m, n, k, aptr, voidAarray, Atype, lda, voidBarray, Btype, ldb, bptr, voidCarray, Ctype, ldc, batchCount, computeType, blasGemmAlgo);
blasStatus = cublasGemmBatchedEx(blasHandle, transa, transb, m, n, k, aptr, voidAarray, Atype, lda, voidBarray, Btype, ldb, bptr, voidCarray, Ctype, ldc, batchCount, computeType, blasGemmAlgo);
// CHECK: blasStatus = rocblas_gemm_batched_ex(blasHandle, transa, transb, m, n, k, aptr, voidAarray_const, Atype, lda, voidBarray_const, Btype, ldb, bptr, voidCarray, Ctype, ldc, batchCount, computeType, blasGemmAlgo);
blasStatus = cublasGemmBatchedEx(blasHandle, transa, transb, m, n, k, aptr, voidAarray_const, Atype, lda, voidBarray_const, Btype, ldb, bptr, voidCarray, Ctype, ldc, batchCount, computeType, blasGemmAlgo);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const void* alpha, const void* A, cudaDataType Atype, int lda, long long int strideA, const void* B, cudaDataType Btype, int ldb, long long int strideB, const void* beta, void* C, cudaDataType Ctype, int ldc, long long int strideC, int batchCount, cublasComputeType_t computeType, cublasGemmAlgo_t algo);
Expand Down
Loading

0 comments on commit 7e693f0

Please sign in to comment.