From 86cb8b8fda43c71c4793d729364c7836fff14b3b Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 22 Aug 2024 14:42:49 +0000 Subject: [PATCH 01/16] init --- include/flashinfer/bmm_fp8.cuh | 114 +++++++++++++++++++++++++++++++++ python/csrc/bmm_fp8.cu | 53 +++++++++++++++ python/csrc/flashinfer_ops.cu | 1 + python/csrc/flashinfer_ops.h | 2 + python/flashinfer/__init__.py | 1 + python/flashinfer/bmm_fp8.py | 34 ++++++++++ python/setup.py | 1 + python/tests/test_bmm_fp8.py | 14 ++++ 8 files changed, 220 insertions(+) create mode 100644 include/flashinfer/bmm_fp8.cuh create mode 100644 python/csrc/bmm_fp8.cu create mode 100644 python/flashinfer/bmm_fp8.py create mode 100644 python/tests/test_bmm_fp8.py diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh new file mode 100644 index 00000000..3312d93c --- /dev/null +++ b/include/flashinfer/bmm_fp8.cuh @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_BMM_FP8_CUH_ +#define FLASHINFER_BMM_FP8_CUH_ + +#include +#include +#include +#include +#include + +#include + +namespace flashinfer { + +namespace bmm_fp8 { + +template +struct CuBlasLtDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDABLAS_CHECK(destructor(x)); + } + } +}; + +template +class CuBlasLtDescriptor { + public: + T* descriptor() const { return descriptor_.get(); } + T* descriptor() { return descriptor_.get(); } + + protected: + std::unique_ptr> descriptor_; +}; + +class CuBlasLtMatmulDescriptor + : public CuBlasLtDescriptor { + public: + CuBlasLtMatmulDescriptor(cublasComputeType_t compute_type, cudaDataType_t scale_type) { + cublasLtMatmulDesc_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK(cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) { + TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +class CuBlasLtMatrixLayout + : public CuBlasLtDescriptor { + public: + CuBlasLtMatrixLayout(cudaDataType_t type, uint64_t rows, uint64_t cols, int64_t ld, + bool t = false) { + cublasLtMatrixLayout_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK( + cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) { + TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, + int batch_size, int m, int n, int k) { + auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F); + matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); + matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); + + auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, m, k, k, true); + auto b_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, n, k); + auto d_desp = CuBlasLtMatrixLayout(CUDA_R_16BF, m, n, m); + + if (batch_size > 1) { + int64_t stride_a = m * k; + int64_t stride_b = k * n; + int64_t stride_d = m * n; + a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + a_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a); + b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + b_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b); + d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d); + } + + const float alpha = 1.0f; + const float beta = 0.0f; + cublasStatus_t status = cublasLtMatmul( + at::cuda::getCurrentCUDABlasLtHandle(), matmul_desp.descriptor(), &alpha, A, + a_desp.descriptor(), B, b_desp.descriptor(), &beta, nullptr, d_desp.descriptor(), D, + d_desp.descriptor(), nullptr, nullptr, 0, at::cuda::getCurrentCUDAStream()); + TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status)); +} + +} // namespace bmm_fp8 +} // namespace flashinfer + +#endif // FLASHINFER_BMM_FP8_CUH_ diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu new file mode 100644 index 00000000..65b0afb7 --- /dev/null +++ b/python/csrc/bmm_fp8.cu @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include "flashinfer_ops.h" + +using namespace flashinfer; + +void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Tensor& result) { + TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor"); + TORCH_CHECK(weight.is_cuda(), "Weight must be a CUDA tensor"); + TORCH_CHECK(result.is_cuda(), "Result must be a CUDA tensor"); + TORCH_CHECK(input.dim() == 3, "Expected 3D tensor for input"); + TORCH_CHECK(weight.dim() == 3, "Expected 3D tensor for weight"); + TORCH_CHECK(result.dim() == 3, "Expected 3D tensor for result"); + TORCH_CHECK(input.size(0) == weight.size(0) && input.size(0) == result.size(0), + "Batch sizes must match"); + TORCH_CHECK(input.size(2) == weight.size(1), "Incompatible matrix sizes"); + TORCH_CHECK(input.size(1) == result.size(1) && weight.size(2) == result.size(2), + "Result tensor has incorrect shape"); + TORCH_CHECK(input.scalar_type() == torch::kFloat8_e4m3fn, "input must be Float8_e4m3fn"); + TORCH_CHECK(weight.scalar_type() == torch::kFloat8_e4m3fn, "weight must be Float8_e4m3fn"); + TORCH_CHECK(result.scalar_type() == torch::kBFloat16, "Result must be BFloat16"); + + auto batch_size = input.size(0); + auto m = input.size(1); + auto k = input.size(2); + auto n = weight.size(2); + + if (result.scalar_type() == at::ScalarType::BFloat16) { + flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(static_cast<__nv_fp8_e4m3*>(input.data_ptr()), + static_cast<__nv_fp8_e4m3*>(weight.data_ptr()), + static_cast<__nv_bfloat16*>(result.data_ptr()), + batch_size, m, n, k); + } +} diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 6f8217d2..cb4343a1 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -48,6 +48,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); py::class_(m, "CutlassSegmentGEMMPyTorchWrapper") .def(py::init()) .def("register_workspace", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer) diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 4efe1ca2..db9d0e0f 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -104,6 +104,8 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, torch::Tensor output_indptr, const std::string& bitorder); +void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Tensor& result); + class CutlassSegmentGEMMPyTorchWrapper { public: void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer); diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index cce14719..eb3b4a48 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -28,6 +28,7 @@ single_decode_with_kv_cache, ) from .activation import gelu_tanh_and_mul, silu_and_mul +from .bmm_fp8 import bmm_fp8 from .group_gemm import SegmentGEMMWrapper from .norm import fused_add_rmsnorm, rmsnorm from .page import append_paged_kv_cache diff --git a/python/flashinfer/bmm_fp8.py b/python/flashinfer/bmm_fp8.py new file mode 100644 index 00000000..87439002 --- /dev/null +++ b/python/flashinfer/bmm_fp8.py @@ -0,0 +1,34 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +# mypy: disable-error-code="attr-defined" +try: + from . import _kernels +except ImportError as e: + import logging + import os + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + + +def bmm_fp8(input: torch.Tensor, weight: torch.Tensor, result: torch.Tensor): + _kernels.bmm_fp8(input, weight, result) diff --git a/python/setup.py b/python/setup.py index a23c6ac3..2fd605be 100644 --- a/python/setup.py +++ b/python/setup.py @@ -344,6 +344,7 @@ def __init__(self, *args, **kwargs) -> None: "csrc/rope.cu", "csrc/group_gemm.cu", "csrc/quantization.cu", + "csrc/bmm_fp8.cu", ], include_dirs=include_dirs, extra_compile_args=extra_compile_args, diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py new file mode 100644 index 00000000..92b194db --- /dev/null +++ b/python/tests/test_bmm_fp8.py @@ -0,0 +1,14 @@ +import torch +from flashinfer import bmm_fp8 + +input = torch.randn([16, 64, 48], device="cuda", dtype=torch.bfloat16) +# transpose, cuBLASLt row major column major +input_fp8 = input.to(torch.float8_e4m3fn).transpose(-1, -2) +mat2 = torch.randn([16, 64, 80], device="cuda", dtype=torch.bfloat16) +mat2_fp8 = mat2.to(torch.float8_e4m3fn) + +res = torch.empty([16, 48, 80], device="cuda", dtype=torch.bfloat16) + +bmm_fp8(input_fp8, mat2_fp8, res) + +print(res) From 3db2db46b0b7623788534ef0633dad9d5fbbdbc5 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 25 Aug 2024 10:26:51 +0000 Subject: [PATCH 02/16] fix --- include/flashinfer/bmm_fp8.cuh | 10 +++++----- python/csrc/bmm_fp8.cu | 6 +++--- python/tests/test_bmm_fp8.py | 19 +++++++++++-------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh index 3312d93c..02f49021 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/bmm_fp8.cuh @@ -83,7 +83,7 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); - auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, m, k, k, true); + auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, m, k); auto b_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, n, k); auto d_desp = CuBlasLtMatrixLayout(CUDA_R_16BF, m, n, m); @@ -101,10 +101,10 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ const float alpha = 1.0f; const float beta = 0.0f; - cublasStatus_t status = cublasLtMatmul( - at::cuda::getCurrentCUDABlasLtHandle(), matmul_desp.descriptor(), &alpha, A, - a_desp.descriptor(), B, b_desp.descriptor(), &beta, nullptr, d_desp.descriptor(), D, - d_desp.descriptor(), nullptr, nullptr, 0, at::cuda::getCurrentCUDAStream()); + cublasStatus_t status = + cublasLtMatmul(at::cuda::getCurrentCUDABlasLtHandle(), matmul_desp.descriptor(), &alpha, A, + a_desp.descriptor(), B, b_desp.descriptor(), &beta, D, d_desp.descriptor(), D, + d_desp.descriptor(), nullptr, nullptr, 0, at::cuda::getCurrentCUDAStream()); TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status)); } diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index 65b0afb7..0b06cb01 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -45,9 +45,9 @@ void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Ten auto n = weight.size(2); if (result.scalar_type() == at::ScalarType::BFloat16) { - flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(static_cast<__nv_fp8_e4m3*>(input.data_ptr()), - static_cast<__nv_fp8_e4m3*>(weight.data_ptr()), + flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(static_cast<__nv_fp8_e4m3*>(weight.data_ptr()), + static_cast<__nv_fp8_e4m3*>(input.data_ptr()), static_cast<__nv_bfloat16*>(result.data_ptr()), - batch_size, m, n, k); + batch_size, n, m, k); } } diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index 92b194db..72a06a18 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -1,14 +1,17 @@ import torch +import numpy as np from flashinfer import bmm_fp8 -input = torch.randn([16, 64, 48], device="cuda", dtype=torch.bfloat16) -# transpose, cuBLASLt row major column major -input_fp8 = input.to(torch.float8_e4m3fn).transpose(-1, -2) -mat2 = torch.randn([16, 64, 80], device="cuda", dtype=torch.bfloat16) -mat2_fp8 = mat2.to(torch.float8_e4m3fn) - -res = torch.empty([16, 48, 80], device="cuda", dtype=torch.bfloat16) +input = torch.randn([1, 48, 64], device="cuda", dtype=torch.bfloat16) +input_fp8 = input.to(torch.float8_e4m3fn) +mat2 = torch.randn([1, 64, 80], device="cuda", dtype=torch.bfloat16) +mat2_fp8 = mat2.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous() +mat2_fp8 = mat2_fp8.transpose(-1, -2) +res = torch.empty([1, 48, 80], device="cuda", dtype=torch.bfloat16) bmm_fp8(input_fp8, mat2_fp8, res) +res_bf16 = input @ mat2 -print(res) +np.testing.assert_allclose( + res.float().cpu().numpy(), res_bf16.float().cpu().numpy(), rtol=1e-1, atol=1e-1 +) From d7799c667868e0806dfab98bb55d425b96552fc9 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 11:26:31 +0000 Subject: [PATCH 03/16] kernel is right --- include/flashinfer/bmm_fp8.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh index 02f49021..3312d93c 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/bmm_fp8.cuh @@ -83,7 +83,7 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); - auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, m, k); + auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, m, k, k, true); auto b_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, n, k); auto d_desp = CuBlasLtMatrixLayout(CUDA_R_16BF, m, n, m); @@ -101,10 +101,10 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ const float alpha = 1.0f; const float beta = 0.0f; - cublasStatus_t status = - cublasLtMatmul(at::cuda::getCurrentCUDABlasLtHandle(), matmul_desp.descriptor(), &alpha, A, - a_desp.descriptor(), B, b_desp.descriptor(), &beta, D, d_desp.descriptor(), D, - d_desp.descriptor(), nullptr, nullptr, 0, at::cuda::getCurrentCUDAStream()); + cublasStatus_t status = cublasLtMatmul( + at::cuda::getCurrentCUDABlasLtHandle(), matmul_desp.descriptor(), &alpha, A, + a_desp.descriptor(), B, b_desp.descriptor(), &beta, nullptr, d_desp.descriptor(), D, + d_desp.descriptor(), nullptr, nullptr, 0, at::cuda::getCurrentCUDAStream()); TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status)); } From 133eb2e00de02bef4974727e932a468a0d4f6310 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 11:48:24 +0000 Subject: [PATCH 04/16] add comment --- python/csrc/bmm_fp8.cu | 3 +++ python/tests/test_bmm_fp8.py | 8 +++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index 0b06cb01..97e7e33d 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -44,6 +44,9 @@ void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Ten auto k = input.size(2); auto n = weight.size(2); + // PyTorch is row major by default. cuBLASLt is column major by default. + // We need row major result as expected. + // A ^ T * B = D, so D ^ T = B ^ T * A if (result.scalar_type() == at::ScalarType::BFloat16) { flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(static_cast<__nv_fp8_e4m3*>(weight.data_ptr()), static_cast<__nv_fp8_e4m3*>(input.data_ptr()), diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index 72a06a18..ba8cd268 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -2,13 +2,15 @@ import numpy as np from flashinfer import bmm_fp8 -input = torch.randn([1, 48, 64], device="cuda", dtype=torch.bfloat16) +input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) input_fp8 = input.to(torch.float8_e4m3fn) -mat2 = torch.randn([1, 64, 80], device="cuda", dtype=torch.bfloat16) +mat2 = torch.randn([16, 64, 80], device="cuda", dtype=torch.bfloat16) +# mat2 row major -> column major mat2_fp8 = mat2.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous() +# make original shape unchanged mat2_fp8 = mat2_fp8.transpose(-1, -2) -res = torch.empty([1, 48, 80], device="cuda", dtype=torch.bfloat16) +res = torch.empty([16, 48, 80], device="cuda", dtype=torch.bfloat16) bmm_fp8(input_fp8, mat2_fp8, res) res_bf16 = input @ mat2 From a4e4d5c05bafc4597ca707aa9348bc976d7f9cf9 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 11:57:30 +0000 Subject: [PATCH 05/16] use torch --- python/tests/test_bmm_fp8.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index ba8cd268..c8cbb06d 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -1,5 +1,4 @@ import torch -import numpy as np from flashinfer import bmm_fp8 input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) @@ -14,6 +13,6 @@ bmm_fp8(input_fp8, mat2_fp8, res) res_bf16 = input @ mat2 -np.testing.assert_allclose( - res.float().cpu().numpy(), res_bf16.float().cpu().numpy(), rtol=1e-1, atol=1e-1 +torch.testing.assert_close( + res.float().cpu(), res_bf16.float().cpu(), rtol=1e-1, atol=1e-1 ) From 3d30ad6a500c73e4edd20da595c4d398ac375d3b Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 12:07:24 +0000 Subject: [PATCH 06/16] upd --- python/csrc/bmm_fp8.cu | 46 ++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index 97e7e33d..21a133a7 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -23,34 +23,32 @@ using namespace flashinfer; -void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Tensor& result) { - TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor"); - TORCH_CHECK(weight.is_cuda(), "Weight must be a CUDA tensor"); - TORCH_CHECK(result.is_cuda(), "Result must be a CUDA tensor"); - TORCH_CHECK(input.dim() == 3, "Expected 3D tensor for input"); - TORCH_CHECK(weight.dim() == 3, "Expected 3D tensor for weight"); - TORCH_CHECK(result.dim() == 3, "Expected 3D tensor for result"); - TORCH_CHECK(input.size(0) == weight.size(0) && input.size(0) == result.size(0), - "Batch sizes must match"); - TORCH_CHECK(input.size(2) == weight.size(1), "Incompatible matrix sizes"); - TORCH_CHECK(input.size(1) == result.size(1) && weight.size(2) == result.size(2), +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) { + TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor"); + TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor"); + TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor"); + TORCH_CHECK(A.dim() == 3, "Expected 3D tensor for A"); + TORCH_CHECK(B.dim() == 3, "Expected 3D tensor for B"); + TORCH_CHECK(D.dim() == 3, "Expected 3D tensor for D"); + TORCH_CHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0), "Batch sizes must match"); + TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes"); + TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2), "Result tensor has incorrect shape"); - TORCH_CHECK(input.scalar_type() == torch::kFloat8_e4m3fn, "input must be Float8_e4m3fn"); - TORCH_CHECK(weight.scalar_type() == torch::kFloat8_e4m3fn, "weight must be Float8_e4m3fn"); - TORCH_CHECK(result.scalar_type() == torch::kBFloat16, "Result must be BFloat16"); + TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn, "A must be Float8_e4m3fn"); + TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn, "B must be Float8_e4m3fn"); + TORCH_CHECK(D.scalar_type() == torch::kBFloat16, "D must be BFloat16"); - auto batch_size = input.size(0); - auto m = input.size(1); - auto k = input.size(2); - auto n = weight.size(2); + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); // PyTorch is row major by default. cuBLASLt is column major by default. - // We need row major result as expected. + // We need row major D as expected. // A ^ T * B = D, so D ^ T = B ^ T * A - if (result.scalar_type() == at::ScalarType::BFloat16) { - flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(static_cast<__nv_fp8_e4m3*>(weight.data_ptr()), - static_cast<__nv_fp8_e4m3*>(input.data_ptr()), - static_cast<__nv_bfloat16*>(result.data_ptr()), - batch_size, n, m, k); + if (D.scalar_type() == at::ScalarType::BFloat16) { + flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( + static_cast<__nv_fp8_e4m3*>(B.data_ptr()), static_cast<__nv_fp8_e4m3*>(A.data_ptr()), + static_cast<__nv_bfloat16*>(D.data_ptr()), batch_size, n, m, k); } } From 091fd1cbb8b7ed4a5a93e5f2275bb80deefe3e5d Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 12:16:23 +0000 Subject: [PATCH 07/16] enable fast_accum --- include/flashinfer/bmm_fp8.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh index 3312d93c..5ef311ad 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/bmm_fp8.cuh @@ -82,6 +82,8 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); + int8_t fast_accum = 1; + matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum); auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, m, k, k, true); auto b_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, n, k); From d0f6b287b2ed4e1c37634365d3d3ae65699c1d8f Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 12:35:49 +0000 Subject: [PATCH 08/16] add workspace --- include/flashinfer/bmm_fp8.cuh | 37 +++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh index 5ef311ad..8baa8bad 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/bmm_fp8.cuh @@ -77,6 +77,21 @@ class CuBlasLtMatrixLayout } }; +class CuBlasLtMatmulPreference : public CuBlasLtDescriptor { + public: + CuBlasLtMatmulPreference() { + cublasLtMatmulPreference_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) { + TORCH_CUDABLAS_CHECK( + ::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n, int k) { auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F); @@ -101,12 +116,28 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ d_desp.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d); } + CuBlasLtMatmulPreference preference; + size_t workspace_size = 1024 * 1024; // 1 MiB + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto workspace = allocator.allocate(workspace_size); + cublasLtMatmulHeuristicResult_t heuristic_result = {}; + int returned_result = 0; + auto lt_handle = at::cuda::getCurrentCUDABlasLtHandle(); + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(), + d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result, + &returned_result)); + if (returned_result == 0) { + TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); + } + const float alpha = 1.0f; const float beta = 0.0f; cublasStatus_t status = cublasLtMatmul( - at::cuda::getCurrentCUDABlasLtHandle(), matmul_desp.descriptor(), &alpha, A, - a_desp.descriptor(), B, b_desp.descriptor(), &beta, nullptr, d_desp.descriptor(), D, - d_desp.descriptor(), nullptr, nullptr, 0, at::cuda::getCurrentCUDAStream()); + lt_handle, matmul_desp.descriptor(), &alpha, A, a_desp.descriptor(), B, b_desp.descriptor(), + &beta, nullptr, d_desp.descriptor(), D, d_desp.descriptor(), &heuristic_result.algo, + workspace.mutable_get(), workspace_size, at::cuda::getCurrentCUDAStream()); TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status)); } From ca700211d73bd7357c5e3abf1e89722f2be887af Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 13:28:12 +0000 Subject: [PATCH 09/16] support template --- include/flashinfer/bmm_fp8.cuh | 61 +++++++++++++++++++++++++++++++--- python/csrc/bmm_fp8.cu | 25 +++++++++----- 2 files changed, 73 insertions(+), 13 deletions(-) diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh index 8baa8bad..75fb19bc 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/bmm_fp8.cuh @@ -23,6 +23,7 @@ #include #include +#include namespace flashinfer { @@ -92,17 +93,40 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor +cudaDataType_t get_cuda_data_type() { + if constexpr (std::is_same_v) { + return CUDA_R_8F_E4M3; + } else if constexpr (std::is_same_v) { + return CUDA_R_8F_E5M2; + } else if constexpr (std::is_same_v) { + return CUDA_R_16BF; + } else if constexpr (std::is_same_v) { + return CUDA_R_16F; + } else { + throw std::runtime_error("Unsupported type"); + } +} + +template +void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, + int k) { auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); int8_t fast_accum = 1; matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum); - auto a_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, m, k, k, true); - auto b_desp = CuBlasLtMatrixLayout(CUDA_R_8F_E4M3, k, n, k); - auto d_desp = CuBlasLtMatrixLayout(CUDA_R_16BF, m, n, m); + cudaDataType_t a_type = get_cuda_data_type(); + cudaDataType_t b_type = get_cuda_data_type(); + cudaDataType_t d_type = get_cuda_data_type
(); + if (std::is_same_v && std::is_same_v) { + throw std::runtime_error("Unsupported combination: both A and B are e5m2"); + } + + auto a_desp = CuBlasLtMatrixLayout(a_type, m, k, k, true); + auto b_desp = CuBlasLtMatrixLayout(b_type, k, n, k); + auto d_desp = CuBlasLtMatrixLayout(d_type, m, n, m); if (batch_size > 1) { int64_t stride_a = m * k; @@ -141,6 +165,33 @@ void bmm_fp8_internal_cublaslt(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, _ TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, at::cuda::blas::_cublasGetErrorEnum(status)); } +template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>( + const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n, + int k); + +template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(const __nv_fp8_e4m3* A, + const __nv_fp8_e4m3* B, + half* D, int batch_size, + int m, int n, int k); + +template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>( + const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n, + int k); + +template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(const __nv_fp8_e4m3* A, + const __nv_fp8_e5m2* B, + half* D, int batch_size, + int m, int n, int k); + +template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>( + const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n, + int k); + +template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(const __nv_fp8_e5m2* A, + const __nv_fp8_e4m3* B, + half* D, int batch_size, + int m, int n, int k); + } // namespace bmm_fp8 } // namespace flashinfer diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index 21a133a7..ca669718 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -20,6 +20,7 @@ #include #include "flashinfer_ops.h" +#include "pytorch_extension_utils.h" using namespace flashinfer; @@ -34,9 +35,12 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) { TORCH_CHECK(A.size(2) == B.size(1), "Incompatible matrix sizes"); TORCH_CHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2), "Result tensor has incorrect shape"); - TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn, "A must be Float8_e4m3fn"); - TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn, "B must be Float8_e4m3fn"); - TORCH_CHECK(D.scalar_type() == torch::kBFloat16, "D must be BFloat16"); + TORCH_CHECK(A.scalar_type() == torch::kFloat8_e4m3fn || A.scalar_type() == torch::kFloat8_e5m2, + "A must be Float8_e4m3fn or Float8_e5m2"); + TORCH_CHECK(B.scalar_type() == torch::kFloat8_e4m3fn || B.scalar_type() == torch::kFloat8_e5m2, + "B must be Float8_e4m3fn or Float8_e5m2"); + TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf, + "D must be BFloat16 or Half"); auto batch_size = A.size(0); auto m = A.size(1); @@ -46,9 +50,14 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) { // PyTorch is row major by default. cuBLASLt is column major by default. // We need row major D as expected. // A ^ T * B = D, so D ^ T = B ^ T * A - if (D.scalar_type() == at::ScalarType::BFloat16) { - flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( - static_cast<__nv_fp8_e4m3*>(B.data_ptr()), static_cast<__nv_fp8_e4m3*>(A.data_ptr()), - static_cast<__nv_bfloat16*>(D.data_ptr()), batch_size, n, m, k); - } + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(B.scalar_type(), b_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(A.scalar_type(), a_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] { + flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( + static_cast(B.data_ptr()), static_cast(A.data_ptr()), + static_cast(D.data_ptr()), batch_size, n, m, k); + return true; + }); + }); + }); } From d649f3b4bfe53d2596ef82f4bf677f04258245ad Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 13:56:07 +0000 Subject: [PATCH 10/16] update test --- python/tests/test_bmm_fp8.py | 51 +++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index c8cbb06d..407e5a68 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -1,18 +1,39 @@ +import pytest import torch from flashinfer import bmm_fp8 -input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) -input_fp8 = input.to(torch.float8_e4m3fn) -mat2 = torch.randn([16, 64, 80], device="cuda", dtype=torch.bfloat16) -# mat2 row major -> column major -mat2_fp8 = mat2.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous() -# make original shape unchanged -mat2_fp8 = mat2_fp8.transpose(-1, -2) - -res = torch.empty([16, 48, 80], device="cuda", dtype=torch.bfloat16) -bmm_fp8(input_fp8, mat2_fp8, res) -res_bf16 = input @ mat2 - -torch.testing.assert_close( - res.float().cpu(), res_bf16.float().cpu(), rtol=1e-1, atol=1e-1 -) + +@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): + if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: + pytest.skip("Invalid combination: both input and mat2 are e5m2") + + input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) + input_fp8 = input.to(input_dtype) + + mat2 = torch.randn([16, 64, 80], device="cuda", dtype=torch.bfloat16) + # mat2 row major -> column major + mat2_fp8 = mat2.to(mat2_dtype).transpose(-1, -2).contiguous() + # make original shape unchanged + mat2_fp8 = mat2_fp8.transpose(-1, -2) + + res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) + bmm_fp8(input_fp8, mat2_fp8, res) + + res_ref = (input @ mat2).to(res_dtype) + + res_float = res.float().cpu() + res_ref_float = res_ref.float().cpu() + + is_close = torch.isclose(res_float, res_ref_float, rtol=1e-1, atol=1e-1) + + total_elements = res_float.numel() + unequal_elements = torch.sum(~is_close).item() + unequal_percentage = (unequal_elements / total_elements) * 100 + assert unequal_percentage < 10 + + +if __name__ == "__main__": + pytest.main([__file__]) From 4f9fccf7fb8a63d742a68dbfad9351fb5830f411 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 14:34:30 +0000 Subject: [PATCH 11/16] add scale --- include/flashinfer/bmm_fp8.cuh | 36 ++++++++++++++++++---------------- python/csrc/bmm_fp8.cu | 9 +++++++-- python/csrc/flashinfer_ops.h | 3 ++- python/flashinfer/bmm_fp8.py | 10 ++++++++-- 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/include/flashinfer/bmm_fp8.cuh b/include/flashinfer/bmm_fp8.cuh index 75fb19bc..36c2e1f9 100644 --- a/include/flashinfer/bmm_fp8.cuh +++ b/include/flashinfer/bmm_fp8.cuh @@ -109,14 +109,19 @@ cudaDataType_t get_cuda_data_type() { } template -void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, - int k) { +void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, int m, int n, int k, + const float* A_scale, const float* B_scale) { + const void* A_scale_ptr = static_cast(A_scale); + const void* B_scale_ptr = static_cast(B_scale); auto matmul_desp = CuBlasLtMatmulDescriptor(CUBLAS_COMPUTE_32F, CUDA_R_32F); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); int8_t fast_accum = 1; matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum); + matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr); + matmul_desp.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr); + cudaDataType_t a_type = get_cuda_data_type(); cudaDataType_t b_type = get_cuda_data_type(); cudaDataType_t d_type = get_cuda_data_type
(); @@ -167,30 +172,27 @@ void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size, template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n, - int k); + int k, const float* A_scale, const float* B_scale); -template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(const __nv_fp8_e4m3* A, - const __nv_fp8_e4m3* B, - half* D, int batch_size, - int m, int n, int k); +template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>( + const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k, + const float* A_scale, const float* B_scale); template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>( const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, __nv_bfloat16* D, int batch_size, int m, int n, - int k); + int k, const float* A_scale, const float* B_scale); -template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(const __nv_fp8_e4m3* A, - const __nv_fp8_e5m2* B, - half* D, int batch_size, - int m, int n, int k); +template void bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>( + const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, half* D, int batch_size, int m, int n, int k, + const float* A_scale, const float* B_scale); template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, __nv_bfloat16* D, int batch_size, int m, int n, - int k); + int k, const float* A_scale, const float* B_scale); -template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(const __nv_fp8_e5m2* A, - const __nv_fp8_e4m3* B, - half* D, int batch_size, - int m, int n, int k); +template void bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>( + const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, half* D, int batch_size, int m, int n, int k, + const float* A_scale, const float* B_scale); } // namespace bmm_fp8 } // namespace flashinfer diff --git a/python/csrc/bmm_fp8.cu b/python/csrc/bmm_fp8.cu index ca669718..0f7da212 100644 --- a/python/csrc/bmm_fp8.cu +++ b/python/csrc/bmm_fp8.cu @@ -24,7 +24,8 @@ using namespace flashinfer; -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) { +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, + torch::Tensor& A_scale, torch::Tensor& B_scale) { TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor"); TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor"); TORCH_CHECK(D.is_cuda(), "D must be a CUDA tensor"); @@ -42,6 +43,9 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) { TORCH_CHECK(D.scalar_type() == torch::kBFloat16 || D.scalar_type() == torch::kHalf, "D must be BFloat16 or Half"); + TORCH_CHECK(A_scale.scalar_type() == torch::kFloat32 && B_scale.scalar_type() == torch::kFloat32, + "A_scale and B_scale must be Float32"); + auto batch_size = A.size(0); auto m = A.size(1); auto k = A.size(2); @@ -55,7 +59,8 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D) { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(D.scalar_type(), d_type, [&] { flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( static_cast(B.data_ptr()), static_cast(A.data_ptr()), - static_cast(D.data_ptr()), batch_size, n, m, k); + static_cast(D.data_ptr()), batch_size, n, m, k, + static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr())); return true; }); }); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index db9d0e0f..b37afeab 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -104,7 +104,8 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, torch::Tensor output_indptr, const std::string& bitorder); -void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Tensor& result); +void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Tensor& result, + torch::Tensor& A_scale, torch::Tensor& B_scale); class CutlassSegmentGEMMPyTorchWrapper { public: diff --git a/python/flashinfer/bmm_fp8.py b/python/flashinfer/bmm_fp8.py index 87439002..906a241c 100644 --- a/python/flashinfer/bmm_fp8.py +++ b/python/flashinfer/bmm_fp8.py @@ -30,5 +30,11 @@ raise e -def bmm_fp8(input: torch.Tensor, weight: torch.Tensor, result: torch.Tensor): - _kernels.bmm_fp8(input, weight, result) +def bmm_fp8( + input: torch.Tensor, + weight: torch.Tensor, + result: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +): + _kernels.bmm_fp8(input, weight, result, A_scale, B_scale) From 032f29cc7252941778835185ee35e77dce6a4ad2 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 25 Aug 2024 14:41:32 +0000 Subject: [PATCH 12/16] upd --- python/csrc/flashinfer_ops.h | 2 +- python/flashinfer/bmm_fp8.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index b37afeab..ff5ec5e7 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -104,7 +104,7 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, torch::Tensor output_indptr, const std::string& bitorder); -void bmm_fp8(const torch::Tensor& input, const torch::Tensor& weight, torch::Tensor& result, +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, torch::Tensor& A_scale, torch::Tensor& B_scale); class CutlassSegmentGEMMPyTorchWrapper { diff --git a/python/flashinfer/bmm_fp8.py b/python/flashinfer/bmm_fp8.py index 906a241c..da7bf3cc 100644 --- a/python/flashinfer/bmm_fp8.py +++ b/python/flashinfer/bmm_fp8.py @@ -31,10 +31,10 @@ def bmm_fp8( - input: torch.Tensor, - weight: torch.Tensor, - result: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, A_scale: torch.Tensor, B_scale: torch.Tensor, ): - _kernels.bmm_fp8(input, weight, result, A_scale, B_scale) + _kernels.bmm_fp8(A, B, D, A_scale, B_scale) From 2a184ad49c2ccc295cc04854808752892e7b0064 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 27 Aug 2024 03:53:33 +1000 Subject: [PATCH 13/16] update test --- python/tests/test_bmm_fp8.py | 37 +++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index 407e5a68..a482efb6 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -1,8 +1,18 @@ import pytest import torch +import torch.nn.functional as F + from flashinfer import bmm_fp8 +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12) + scale = finfo.max / abs_max + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + @pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) @@ -11,28 +21,21 @@ def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): pytest.skip("Invalid combination: both input and mat2 are e5m2") input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) - input_fp8 = input.to(input_dtype) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) - mat2 = torch.randn([16, 64, 80], device="cuda", dtype=torch.bfloat16) - # mat2 row major -> column major - mat2_fp8 = mat2.to(mat2_dtype).transpose(-1, -2).contiguous() - # make original shape unchanged - mat2_fp8 = mat2_fp8.transpose(-1, -2) + # mat2 row major -> column major + mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose( + -2, -1 + ) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) - bmm_fp8(input_fp8, mat2_fp8, res) - - res_ref = (input @ mat2).to(res_dtype) - - res_float = res.float().cpu() - res_ref_float = res_ref.float().cpu() + bmm_fp8(input_fp8, mat2_fp8, res, input_inv_s, mat2_inv_s) - is_close = torch.isclose(res_float, res_ref_float, rtol=1e-1, atol=1e-1) + reference = torch.bmm(input, mat2) - total_elements = res_float.numel() - unequal_elements = torch.sum(~is_close).item() - unequal_percentage = (unequal_elements / total_elements) * 100 - assert unequal_percentage < 10 + cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + assert cos_sim > 0.98 if __name__ == "__main__": From 7c226022acc8e96073a10903dc32922f336db5c8 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 27 Aug 2024 04:29:36 +1000 Subject: [PATCH 14/16] add doc string --- python/flashinfer/bmm_fp8.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/flashinfer/bmm_fp8.py b/python/flashinfer/bmm_fp8.py index da7bf3cc..8db12714 100644 --- a/python/flashinfer/bmm_fp8.py +++ b/python/flashinfer/bmm_fp8.py @@ -37,4 +37,23 @@ def bmm_fp8( A_scale: torch.Tensor, B_scale: torch.Tensor, ): + r"""BMM FP8 + + Parameters + ---------- + A: torch.Tensor + Input tensor, shape (b, m, k). + + B: torch.Tensor + Mat2 tensor, shape (b, k, n), should be column major. + + D: torch.Tensor + Out tensor, shape (b, m, n). + + A_scale: torch.Tensor + Scale tensor for A. + + B_scale: torch.Tensor + Scale tensor for B. + """ _kernels.bmm_fp8(A, B, D, A_scale, B_scale) From e53d50a4aeb74e8c078295b8393054f3de0841b6 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 27 Aug 2024 04:38:56 +1000 Subject: [PATCH 15/16] update --- docs/api/python/{group_gemm.rst => gemm.rst} | 0 docs/index.rst | 2 +- python/flashinfer/__init__.py | 13 ++--- python/flashinfer/bmm_fp8.py | 59 -------------------- python/flashinfer/{group_gemm.py => gemm.py} | 35 +++++++++++- 5 files changed, 40 insertions(+), 69 deletions(-) rename docs/api/python/{group_gemm.rst => gemm.rst} (100%) delete mode 100644 python/flashinfer/bmm_fp8.py rename python/flashinfer/{group_gemm.py => gemm.py} (93%) diff --git a/docs/api/python/group_gemm.rst b/docs/api/python/gemm.rst similarity index 100% rename from docs/api/python/group_gemm.rst rename to docs/api/python/gemm.rst diff --git a/docs/index.rst b/docs/index.rst index ce0129f7..0a4dd61e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,7 +33,7 @@ FlashInfer is a library for Large Language Models that provides high-performance api/python/sparse api/python/page api/python/sampling - api/python/group_gemm + api/python/gemm api/python/norm api/python/rope api/python/quantization diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index eb3b4a48..a89f985d 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -14,10 +14,11 @@ limitations under the License. """ +from .activation import gelu_tanh_and_mul, silu_and_mul from .cascade import ( - MultiLevelCascadeAttentionWrapper, BatchDecodeWithSharedPrefixPagedKVCacheWrapper, BatchPrefillWithSharedPrefixPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper, merge_state, merge_state_in_place, merge_states, @@ -27,9 +28,7 @@ CUDAGraphBatchDecodeWithPagedKVCacheWrapper, single_decode_with_kv_cache, ) -from .activation import gelu_tanh_and_mul, silu_and_mul -from .bmm_fp8 import bmm_fp8 -from .group_gemm import SegmentGEMMWrapper +from .gemm import SegmentGEMMWrapper, bmm_fp8 from .norm import fused_add_rmsnorm, rmsnorm from .page import append_paged_kv_cache from .prefill import ( @@ -47,15 +46,15 @@ ) from .sampling import ( chain_speculative_sampling, + min_p_sampling_from_probs, sampling_from_probs, - top_k_renorm_prob, top_k_mask_logits, + top_k_renorm_prob, top_k_sampling_from_probs, - top_k_top_p_sampling_from_probs, top_k_top_p_sampling_from_logits, + top_k_top_p_sampling_from_probs, top_p_renorm_prob, top_p_sampling_from_probs, - min_p_sampling_from_probs, ) from .sparse import BlockSparseAttentionWrapper diff --git a/python/flashinfer/bmm_fp8.py b/python/flashinfer/bmm_fp8.py deleted file mode 100644 index 8db12714..00000000 --- a/python/flashinfer/bmm_fp8.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Copyright (c) 2024 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import torch - -# mypy: disable-error-code="attr-defined" -try: - from . import _kernels -except ImportError as e: - import logging - import os - - if os.environ.get("BUILD_DOC", "0") == "1": - _kernels = None - logging.warning("Kernels are not loaded in documentation build mode.") - else: - raise e - - -def bmm_fp8( - A: torch.Tensor, - B: torch.Tensor, - D: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, -): - r"""BMM FP8 - - Parameters - ---------- - A: torch.Tensor - Input tensor, shape (b, m, k). - - B: torch.Tensor - Mat2 tensor, shape (b, k, n), should be column major. - - D: torch.Tensor - Out tensor, shape (b, m, n). - - A_scale: torch.Tensor - Scale tensor for A. - - B_scale: torch.Tensor - Scale tensor for B. - """ - _kernels.bmm_fp8(A, B, D, A_scale, B_scale) diff --git a/python/flashinfer/group_gemm.py b/python/flashinfer/gemm.py similarity index 93% rename from python/flashinfer/group_gemm.py rename to python/flashinfer/gemm.py index 971f87c1..4b85b2de 100644 --- a/python/flashinfer/group_gemm.py +++ b/python/flashinfer/gemm.py @@ -14,16 +14,18 @@ limitations under the License. """ -import torch from typing import Optional + +import torch + from .utils import get_indptr # mypy: disable-error-code="attr-defined" try: from . import _kernels except ImportError as e: - import os import logging + import os if os.environ.get("BUILD_DOC", "0") == "1": _kernels = None @@ -194,3 +196,32 @@ def run( ) forward = run + + +def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +): + r"""BMM FP8 + + Parameters + ---------- + A: torch.Tensor + Input tensor, shape (b, m, k). + + B: torch.Tensor + Mat2 tensor, shape (b, k, n), should be column major. + + D: torch.Tensor + Out tensor, shape (b, m, n). + + A_scale: torch.Tensor + Scale tensor for A. + + B_scale: torch.Tensor + Scale tensor for B. + """ + _kernels.bmm_fp8(A, B, D, A_scale, B_scale) From 125b91a238ed1afe96a208f6445d8788c0cc3063 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 27 Aug 2024 05:26:23 +1000 Subject: [PATCH 16/16] update --- python/flashinfer/gemm.py | 36 ++++++++++++++++++++++++++---------- python/tests/test_bmm_fp8.py | 2 +- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 4b85b2de..81d51f79 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -201,27 +201,43 @@ def run( def bmm_fp8( A: torch.Tensor, B: torch.Tensor, - D: torch.Tensor, A_scale: torch.Tensor, B_scale: torch.Tensor, -): + dtype: torch.dtype, + out: torch.Tensor = None, +) -> torch.Tensor: r"""BMM FP8 Parameters ---------- A: torch.Tensor - Input tensor, shape (b, m, k). + Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2. B: torch.Tensor - Mat2 tensor, shape (b, k, n), should be column major. - - D: torch.Tensor - Out tensor, shape (b, m, n). + Mat2 tensor, shape (b, k, n), should be column major, fp8 e4m3 or fp8 e5m2. A_scale: torch.Tensor - Scale tensor for A. + Scale tensor for A, float. B_scale: torch.Tensor - Scale tensor for B. + Scale tensor for B, float. + + dtype: torch.dtype + out dtype, bf16 or fp16. + + out: torch.Tensor + Out tensor, shape (b, m, n), bf16 or fp16. + + Returns + ------- + out: torch.Tensor + Out tensor, shape (b, m, n), bf16 or fp16. """ - _kernels.bmm_fp8(A, B, D, A_scale, B_scale) + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + _kernels.bmm_fp8(A, B, out, A_scale, B_scale) + return out diff --git a/python/tests/test_bmm_fp8.py b/python/tests/test_bmm_fp8.py index a482efb6..59b36192 100644 --- a/python/tests/test_bmm_fp8.py +++ b/python/tests/test_bmm_fp8.py @@ -30,7 +30,7 @@ def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) - bmm_fp8(input_fp8, mat2_fp8, res, input_inv_s, mat2_inv_s) + bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) reference = torch.bmm(input, mat2)