Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support bmm fp8 #469

Merged
merged 16 commits into from
Aug 26, 2024
Prev Previous commit
Next Next commit
support template
  • Loading branch information
zhyncs committed Aug 26, 2024
commit ca700211d73bd7357c5e3abf1e89722f2be887af
61 changes: 56 additions & 5 deletions include/flashinfer/bmm_fp8.cuh
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
#include <torch/extension.h>

#include <stdexcept>
#include <type_traits>

namespace flashinfer {

@@ -92,17 +93,40 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<cublasLtMatmulPrefere
}
};

void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D,
int batch_size, int m, int n, int k) {
template <typename T>
cudaDataType_t get_cuda_data_type() {
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
return CUDA_R_8F_E4M3;
} else if constexpr (std::is_same_v<T, __nv_fp8_e5m2>) {
return CUDA_R_8F_E5M2;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return CUDA_R_16BF;
} else if constexpr (std::is_same_v<T, half>) {
return CUDA_R_16F;
} else {
throw std::runtime_error("Unsupported type");
}
}

template <typename AT, typename BT, typename DT>
void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n,
int k) {
auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F);
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T);
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N);
int8_t fast_accum = 1;
matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum);

auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, m, k, k, true);
auto b_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, n, k);
auto d_desp = CuBlasLtMatrixLayout(CUDA_R_16BF, m, n, m);
cudaDataType_t a_type = get_cuda_data_type<AT>();
cudaDataType_t b_type = get_cuda_data_type<BT>();
cudaDataType_t d_type = get_cuda_data_type<DT>();
if (std::is_same_v<AT, __nv_fp8_e5m2> && std::is_same_v<BT, __nv_fp8_e5m2>) {
throw std::runtime_error("Unsupported combination: both A and B are e5m2");
}

auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true);
auto b_desp = CuBlasLtMatrixLayout(b_type, k, n, k);
auto d_desp = CuBlasLtMatrixLayout(d_type, m, n, m);

if (batch_size > 1) {
int64_t stride_a = m * k;
@@ -141,6 +165,33 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status));
}

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(const __nv_fp8_e4m3* A,
const __nv_fp8_e4m3* B,
half* D, int batch_size,
int m, int n, int k);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(const __nv_fp8_e4m3* A,
const __nv_fp8_e5m2* B,
half* D, int batch_size,
int m, int n, int k);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n,
int k);

template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(const __nv_fp8_e5m2* A,
const __nv_fp8_e4m3* B,
half* D, int batch_size,
int m, int n, int k);

} // namespace bmm_fp8
} // namespace flashinfer

25 changes: 17 additions & 8 deletions python/csrc/bmm_fp8.cu
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
#include <flashinfer/bmm_fp8.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

@@ -34,9 +35,12 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) {
TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes");
TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2),
"Result tensor has incorrect shape");
TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn, "A must be Float8_e4m3fn");
TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn, "B must be Float8_e4m3fn");
TORCH_CHECK(D.scalar_type() == torch::kBFloat16, "D must be BFloat16");
TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn || A.scalar_type() == torch::kFloat8_e5m2,
"A must be Float8_e4m3fn or Float8_e5m2");
TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn || B.scalar_type() == torch::kFloat8_e5m2,
"B must be Float8_e4m3fn or Float8_e5m2");
TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf,
"D must be BFloat16 or Half");

auto batch_size = A.size(0);
auto m = A.size(1);
@@ -46,9 +50,14 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) {
// PyTorch is row major by default. cuBLASLt is column major by default.
// We need row major D as expected.
// A ^ T * B = D, so D ^ T = B ^ T * A
if (D.scalar_type() == at::ScalarType::BFloat16) {
flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
static_cast<__nv_fp8_e4m3*>(B.data_ptr()), static_cast<__nv_fp8_e4m3*>(A.data_ptr()),
static_cast<__nv_bfloat16*>(D.data_ptr()), batch_size, n, m, k);
}
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] {
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] {
flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k);
return true;
});
});
});
}