Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support #10995

Merged
merged 102 commits into from
Dec 18, 2024

Conversation

dsikka
Copy link
Contributor

@dsikka dsikka commented Dec 8, 2024

Summary

  • Add sparse quantized and unquantized kernels for CUTLASS 3.x.
  • Add compressed tensors support for 2of4 Sparse Only, 2of4 Sparse + INT8/FP8 Quantized Models

From Neural Magic


at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = test_get_sm_version_num();
// Hopper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: what's this comment for?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe for a future PR but there should be more tests here, test more shapes, there should be and opcheck test (see test_cutlass_support_opcheck), a cuda graph test (see test_cutlass_cuda_graph). Use vllm/tests/kernels/test_cutlass.py as inspiration (with the exception of the azp stuff I assume)

@@ -361,7 +361,8 @@ def main(args: argparse.Namespace):
# TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
f"{total_output_tokens / elapsed_time:.2f} output tokens/s, "
f"{total_num_tokens=} | {total_output_tokens=}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like debug cruft and should be reverted if so

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 17 to 20
inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you put this in csrc/core/math.hpp? @SageMoore is adding similar utilities to that file in #10867

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 25 to 33
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should throw an exception here, and it should behave generally the same way that CUTLASS_CHECK does.
(I do like the line number reporting though, so it would be nice if you could add it to both)

Suggested change
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
#define CUDA_CHECK(status) \
{ \
TORCH_CHECK(status == cudaSuccess, \
cudaGetErrorString(status)) \
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 1 to 43
#include <cudaTypedefs.h>

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/detail/dependent_false.hpp"

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/common.hpp"

#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"

#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "sparse_scaled_mm_c3x.cuh"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please clean up these includes. I see some duplicates. Could you try to minimize the number of includes? I.E. no duplicates, and nothing that's unnecessary?

Also please turn clang-format off for the includes, as CUTLASS headers don't tolerate reordering.

// clang-format will break include orders
// clang-format off

#include "your.h"
#include "includes.h"
#include "here.h"

// clang-format on

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be pared down further.

For example:
"cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" already includes "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" and most of our CUTLASS kernels don't interact directly with the code in broadcast_load_epilogue_c3x.hpp so they should only include scaled_mm_epilogues_c3x.hpp.

However this sparsify_and_compress kernel doesn't use any epilogues at all so it shouldn't include either of them.

Could you take another look at these includes and the includes in your other kernels as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. The CUTLASS's CompressorUtility necessitates that a Gemm be defined with all operand types, schedules, etc with an epilogue, albeit the default. I had previously used my default gemm config with ScaledEpilogue for this Gemm but per this review, I replaced that with an on-the-spot Gemm kernel setup similar to the examples provided in CUTLASS. I am also mentioning this in a comment in the code now.

Comment on lines 76 to 77
// Just a dummy value
int32_t n = 128;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you expand on this comment?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was just needed to instantiate a problem shape to use the compressor utility in CUTLASS. I replaced it with 1 in the problem shape directly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put this in a comment in the code so that it is documented there

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 56 to 57
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any requirement for the divisibility of a.stride(0)? Do we test odd values of m?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Since we're doing column-major output in the kernels, there's no requirement. For row-major output, the batch size has to be a multiple of 8.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this was the weight matrix, so batch isn't relevant here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, my bad for misunderstanding. The intermediate dimension of the matmul should be divisible by 4 to be able to follow the 2:4 sparsity. So a.stride(0) % 4 == 0 must hold. I added a check for this divisibility.

Comment on lines 31 to 36

Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this comment is epilogue-specific and the epilogues are not defined in this file, I think this comment should be removed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 535 to 557
def cutlass_compress_entry(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
assert (a.dtype in [
torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
])

# e.dtype: torch.uint8 so elemsPerElemE = 8b / 2b_per_nz = 4
elemsPerElemE = 4

m = a.shape[0]
k = a.shape[1]
a_compressed = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
e = torch.empty((m, k // 2 // elemsPerElemE),
dtype=torch.uint8,
device=a.device)

if not (torch.ops._C.cutlass_compress_entry(a_compressed, e, a)):
raise ValueError

return a_compressed, e


def cutlass_scaled_sparse_mm(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add high-level comments for what these are doing? In particular could you describe what e is?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

vllm/_custom_ops.py Outdated Show resolved Hide resolved
vllm/_custom_ops.py Outdated Show resolved Hide resolved
csrc/sparse/cutlass/sparse_compressor.cu Outdated Show resolved Hide resolved
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me now, thanks for the hard work!

Copy link
Contributor

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM too, just left a few very minor refactor/comment nits. Thanks for the hardwork and iterations!

ops.def(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor b,"
" Tensor e, Tensor a_scales,"
Copy link
Contributor

@LucasWilkinson LucasWilkinson Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you update argument naming to match, i.e. bt_nzs and bt_meta

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

// Interface stride expected from the argument a (will get transposed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you elaborate on this a bit, i.e. add something about the fact that we compute C^t = B^t @ A^t but we assume B is transposed before compressing hence the bt_<x> naming

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

auto layout_A = make_cute_layout<StrideA>(a, "A");
auto layout_D = make_cute_layout<StrideD>(out, "D");

auto stride_At = layout_A.stride();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you add a comment here explaining why At is the same stride as A for cutlass

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{
(int)bt_nzs.size(0), (int)size<0>(layout_A), (int)size<1>(layout_A), 1};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we should avoid c-style casts for consistency (use static_cast here)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// CUTLASS sparse matrix compressor
ops.def(
"cutlass_sparse_compress_entry(Tensor! a_compressed, Tensor! e,"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe update this to match the argument naming for cutlass_scaled_sparse_mm i.e. Tensor! a_nzs, Tensor! a_meta

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe update this to match the argument naming for cutlass_scaled_sparse_mm i.e. Tensor! a_nzs, Tensor! a_meta

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 16, 2024
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe extract status first (like below) so this macro can directly wrap expressions like function calls and not double-evaluate them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

CMakeLists.txt Outdated
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
# GIT_SHALLOW FALSE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be uncommented as FALSE now?

Suggested change
# GIT_SHALLOW FALSE
GIT_SHALLOW FALSE

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah sure. It's also the default I think but better be explicit as you said.

Comment on lines +105 to +106
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

future work: what about per-channel/per-token scales?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. We can also use that for benchmarking. I put this here only because it's similar to the dense benchmarking script.


@classmethod
def get_min_capability(cls) -> int:
return 90
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth leaving a note that this is due to cutlass 3.x kernel restrictions since we do have fp16+int8 support here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 60508ff into vllm-project:main Dec 18, 2024
76 checks passed
SageMoore pushed a commit to neuralmagic/vllm that referenced this pull request Dec 19, 2024
…#10995)

Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
ProExpertProg added a commit to neuralmagic/vllm that referenced this pull request Dec 20, 2024
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
…#10995)

Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants