[bf16] pten matmul cuda kernel support bf16 #39485

Merged (10 commits, Feb 16, 2022)
5 changes: 3 additions & 2 deletions paddle/fluid/framework/pten_utils.cc
@@ -185,8 +185,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
}

KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
- return KernelSignature(op_proto_->type(), GetInputArgsNames(),
-                        GetAttrsArgsNames(), GetOutputArgsNames());
+ return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
+                        GetInputArgsNames(), GetAttrsArgsNames(),
+                        GetOutputArgsNames());
}

std::once_flag kernel_sig_map_init_flag;
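
The pten_utils.cc change above routes the op type through pten::TransToPtenKernelName before constructing the KernelSignature, so that an op whose fluid name differs from its pten kernel name resolves to the right kernel. A minimal, hypothetical sketch of that translation step (not part of this PR; the alias table is an assumption for illustration, not Paddle's actual mapping):

#include <string>
#include <unordered_map>

// Illustrative stand-in for pten::TransToPtenKernelName: translate a fluid op
// type into its pten kernel name, falling back to the original name when no
// alias is registered.
std::string TransToPtenKernelNameSketch(const std::string &fluid_op_type) {
  static const std::unordered_map<std::string, std::string> aliases = {
      {"matmul_v2", "matmul"},  // assumed alias, shown only for illustration
  };
  auto it = aliases.find(fluid_op_type);
  return it == aliases.end() ? fluid_op_type : it->second;
}
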
243 changes: 243 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
@@ -813,6 +813,102 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

// TODO(kexinzhao): add processing code for compute capability < 80 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80, "
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on CUDA versions "
"earlier than 11.0"));

#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on CUDA versions "
"earlier than 11.0"));

#endif // CUDA_VERSION >= 11000
}
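
The two GEMM specializations above rely on the Fortran-order trick noted in their comments: a row-major product C = A·B is obtained by asking cuBLAS for C^T = B^T·A^T in column-major order, which is why B precedes A in the cublasGemmEx call, M and N are exchanged, and the leading dimensions are the row strides K, N and N. A standalone sketch of an equivalent call outside Paddle, assuming CUDA 11+ and a device with compute capability >= 80 (as enforced above); the helper name and trimmed error handling are illustrative:

#include <cublas_v2.h>
#include <cuda_bf16.h>

// Compute C (M x N, row-major) = A (M x K, row-major) * B (K x N, row-major)
// in bfloat16 with float accumulation, by asking cuBLAS for the column-major
// product C^T = B^T * A^T.
cublasStatus_t GemmBf16RowMajor(cublasHandle_t handle, int M, int N, int K,
                                const __nv_bfloat16 *dA, const __nv_bfloat16 *dB,
                                __nv_bfloat16 *dC) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // A/B are swapped and M/N exchanged; the leading dimensions are the
  // row-major row strides of B, A and C respectively.
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                      /*m=*/N, /*n=*/M, /*k=*/K, &alpha,
                      dB, CUDA_R_16BF, /*lda=*/N,
                      dA, CUDA_R_16BF, /*ldb=*/K, &beta,
                      dC, CUDA_R_16BF, /*ldc=*/N,
                      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}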

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
@@ -1208,6 +1304,42 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
}
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// cuBLAS has no bfloat16 GEMV, so we express the matrix-vector product as a
// bfloat16 GEMM (which goes through cublasGemmEx above).
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}

template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// cuBLAS has no bfloat16 GEMV, so we express the matrix-vector product as a
// bfloat16 GEMM (which goes through cublasGemmEx above).
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
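
The GEMV specializations above reduce the bfloat16 matrix-vector product to the bfloat16 GEMM defined earlier, since cuBLAS offers no bf16 GEMV: y = A·x is a GEMM whose second operand has a single column, and the transposed case x^T·A is a 1 x N GEMM with the operands swapped. A CPU reference of that reduction, in float purely for readability (not Paddle code):

#include <vector>

// Shape-level illustration of the dispatch above: trans_a computes y = A^T * x
// as a 1 x N GEMM (x^T * A), otherwise y = A * x as an M x 1 GEMM.
// A is M x N in row-major order; x has length M when trans_a, else length N.
std::vector<float> GemvViaGemm(bool trans_a, int M, int N,
                               const std::vector<float> &A,
                               const std::vector<float> &x) {
  const int out = trans_a ? N : M;
  const int inner = trans_a ? M : N;
  std::vector<float> y(out, 0.0f);
  for (int i = 0; i < out; ++i) {
    for (int k = 0; k < inner; ++k) {
      const float a = trans_a ? A[k * N + i] : A[i * N + k];
      y[i] += a * x[k];
    }
  }
  return y;
}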

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
@@ -1306,6 +1438,91 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 9010
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on CUDA "
"versions earlier than 11.0"));
#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on CUDA "
"versions earlier than 11.0"));
#endif // CUDA_VERSION >= 11000
}
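
cublasGemmStridedBatchedEx runs batchCount independent GEMMs whose operands are laid out at a fixed stride from one another, which is why the code above can hard-code strideC = M * N: each output matrix is dense and stored back to back. A plain-loop CPU sketch of the semantics the call relies on (float for readability, not Paddle code):

#include <cstdint>

// Reference semantics of a strided batched GEMM over row-major buffers:
// for each batch b, C_b = alpha * A_b * B_b + beta * C_b, where the b-th
// operand starts strideX elements after the previous one.
void StridedBatchedGemmRef(int M, int N, int K, float alpha, const float *A,
                           int64_t strideA, const float *B, int64_t strideB,
                           float beta, float *C, int64_t strideC,
                           int batchCount) {
  for (int b = 0; b < batchCount; ++b) {
    const float *Ab = A + b * strideA;
    const float *Bb = B + b * strideB;
    float *Cb = C + b * strideC;  // the kernels above use strideC = M * N
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += Ab[i * K + k] * Bb[k * N + j];
        }
        Cb[i * N + j] = alpha * acc + beta * Cb[i * N + j];
      }
    }
  }
}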

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
@@ -1356,6 +1573,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
}
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}

template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}
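
The pointer-array overloads above simply issue one bf16 GEMM per batch, trading some launch overhead for simplicity. A usage-style sketch of how a caller could assemble the A[k]/B[k]/C[k] pointer arrays from batch-contiguous buffers; the helper name and layout are assumptions for illustration, not part of this PR:

#include <cstdint>
#include <vector>

// Build per-batch pointer arrays for the pointer-array BatchedGEMM overload,
// assuming each batch stores one dense M x K, K x N and M x N matrix
// immediately after the previous one.
template <typename T>
void BuildBatchPointers(T *a_base, T *b_base, T *c_base, int M, int N, int K,
                        int batchCount, std::vector<const T *> *a_ptrs,
                        std::vector<const T *> *b_ptrs,
                        std::vector<T *> *c_ptrs) {
  a_ptrs->resize(batchCount);
  b_ptrs->resize(batchCount);
  c_ptrs->resize(batchCount);
  for (int k = 0; k < batchCount; ++k) {
    (*a_ptrs)[k] = a_base + static_cast<int64_t>(k) * M * K;
    (*b_ptrs)[k] = b_base + static_cast<int64_t>(k) * K * N;
    (*c_ptrs)[k] = c_base + static_cast<int64_t>(k) * M * N;
  }
}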

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,