From e15bf68d6a40b159fe6a410828c65979906e3851 Mon Sep 17 00:00:00 2001
From: lishicheng1996 <847223866@qq.com>
Date: Fri, 21 Jul 2023 16:21:23 +0800
Subject: [PATCH] fix a bug caused by hipcc lambda value capture

---
Review notes (text between the "---" line and the diff is not applied):
- NOTE(review): template arguments that were missing from this patch text
  have been put back: `Blas<phi::GPUContext>`, the
  `reinterpret_cast<... rocblas_half *>` target types, and
  `GEMM_STRIDED_BATCH<float>` / `<double>` in the trailing context.
  Please confirm them against blas_impl.hip.h before applying.
- strideC is now computed in 64 bits and comments were added, so the
  post-image blob id on the `index` line no longer matches the result.

 paddle/phi/kernels/funcs/blas/blas_impl.hip.h | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h
index 6aa41e4f4a2b6..805a718ab85ed 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h
@@ -1173,6 +1173,62 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
   });
 }
 
+// float16 specialization of the strided BatchedGEMM. For every batch index i
+// in [0, batchCount) it computes C_i = alpha * op(A_i) * op(B_i) + beta * C_i
+// with rocblas_hgemm_strided_batched. Matrices are addressed as row-major;
+// problems start strideA / strideB / (M * N) elements apart in A / B / C.
+template <>
+template <>
+inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
+                                               CBLAS_TRANSPOSE transB,
+                                               int M,
+                                               int N,
+                                               int K,
+                                               float16 alpha,
+                                               const float16 *A,
+                                               const float16 *B,
+                                               float16 beta,
+                                               float16 *C,
+                                               int batchCount,
+                                               int64_t strideA,
+                                               int64_t strideB) const {
+  // rocBLAS is column-major (Fortran order) while the arguments here follow
+  // the row-major cblas convention, so the call below computes
+  // C^T = op(B)^T * op(A)^T by passing B first and swapping M with N.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  int ldc = N;
+  rocblas_operation cuTransA = (transA == CblasNoTrans)
+                                   ? rocblas_operation_none
+                                   : rocblas_operation_transpose;
+  rocblas_operation cuTransB = (transB == CblasNoTrans)
+                                   ? rocblas_operation_none
+                                   : rocblas_operation_transpose;
+  // Widen before multiplying so M * N cannot overflow int.
+  const int64_t strideC = static_cast<int64_t>(M) * N;
+  context_.CublasCall([&](rocblas_handle handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::rocblas_hgemm_strided_batched(
+        handle,
+        cuTransB,
+        cuTransA,
+        N,
+        M,
+        K,
+        reinterpret_cast<const rocblas_half *>(&alpha),
+        reinterpret_cast<const rocblas_half *>(B),
+        ldb,
+        strideB,
+        reinterpret_cast<const rocblas_half *>(A),
+        lda,
+        strideA,
+        reinterpret_cast<const rocblas_half *>(&beta),
+        reinterpret_cast<rocblas_half *>(C),
+        ldc,
+        strideC,
+        batchCount));
+  });
+}
+
 // note(wangran16): unknown bug. parameters dislocation when calling
 // GEMM_STRIDED_BATCH<float> and GEMM_STRIDED_BATCH<double>
 template <>