Skip to content

Commit

Permalink
[Kernel] Refactor CUTLASS kernels to always take scales that reside o…
Browse files Browse the repository at this point in the history
…n the GPU (vllm-project#5137)
  • Loading branch information
tlrmchlsmth authored and joerunde committed Jun 4, 2024
1 parent ac902ef commit 4f7c5a1
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,27 @@
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass It's beem modified to support either
// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x.
// Important because this saves us a factor 4x on the number of kernels
// compiled.
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

// clang-format on

namespace cutlass::epilogue::threadblock {

using namespace cute;
Expand All @@ -59,9 +66,11 @@ template<
>
struct VisitorRowOrScalarBroadcast {

// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_row = nullptr;
Element null_default = Element(0);
bool row_broadcast = true;
StrideMNL dRow = {};
};

Expand Down Expand Up @@ -125,25 +134,25 @@ struct VisitorRowOrScalarBroadcast {
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);

if (params_ptr->ptr_row) {
if (params_ptr->row_broadcast) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
}

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
{
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
Expand Down Expand Up @@ -208,9 +217,11 @@ template<
>
struct VisitorColOrScalarBroadcast {

// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;
Element null_default = Element(0);
bool col_broadcast = true;
StrideMNL dCol = {};
};

Expand All @@ -230,11 +241,6 @@ struct VisitorColOrScalarBroadcast {

struct SharedStorage { };

// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast() { }

Expand Down Expand Up @@ -267,7 +273,7 @@ struct VisitorColOrScalarBroadcast {
int m;

// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE void
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);

Expand All @@ -277,7 +283,7 @@ struct VisitorColOrScalarBroadcast {
pred(i) = get<0>(tC_cCol(i)) < m;
}

if (params_ptr->ptr_col) {
if (params_ptr->col_broadcast) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
Expand All @@ -286,8 +292,8 @@ struct VisitorColOrScalarBroadcast {

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if(pred(i)){
dst_v(i) = params_ptr->null_default;
if (pred(i)) {
dst_v(i) = *(params_ptr->ptr_col);
}
}
}
Expand Down
Loading

0 comments on commit 4f7c5a1

Please sign in to comment.