diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 91e6c078ecc45..cd122c5be6155 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10,7 +10,10 @@ #include #include #include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() #include "rocblas/rocblas.h" +#endif #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT @@ -2746,10 +2749,14 @@ void ggml_init_cublas() { static bool initialized = false; if (!initialized) { -#ifdef GGML_USE_HIPBLAS - rocblas_initialize(); - hipDeviceSynchronize(); + +#ifdef __HIP_PLATFORM_AMD__ + // Workaround for a rocBLAS bug when using multiple graphics cards: + // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); #endif + CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0;