Add fp16 grouped gemm support for sm90 (apache#54)
masahi authored and vinx13 committed Mar 19, 2024
1 parent 84f718a commit d8dc9f2
Showing 5 changed files with 370 additions and 172 deletions.
8 changes: 7 additions & 1 deletion cmake/modules/contrib/CUTLASS.cmake
@@ -53,6 +53,12 @@ if(USE_CUDA AND USE_CUTLASS)
### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
set(TVM_CUTLASS_RUNTIME_SRCS "")

# TODO: Should get rid of the postfix 'a' and test sm >= 90
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90|90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
endif()

if (USE_CUDA_FP8)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_fp8_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
@@ -69,4 +75,4 @@ if(USE_CUDA AND USE_CUTLASS)
list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")

message(STATUS "Build with CUTLASS")
endif()
endif()
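With this guard in place, the new fp16 grouped GEMM source is compiled only when TVM is configured for Hopper, e.g. with -DUSE_CUDA=ON -DUSE_CUTLASS=ON -DCMAKE_CUDA_ARCHITECTURES=90a at configure time (an illustrative invocation, not part of this commit; per the TODO above, handling of the 'a' postfix is still provisional).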
70 changes: 70 additions & 0 deletions src/runtime/contrib/cutlass/fp16_group_gemm.cu
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "group_gemm_runner.cuh"


#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

template <>
struct KernelTraits<cutlass::half_t> {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size
using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster
};

namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
  // Recommended size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
CHECK_EQ(workspace->ndim, 1);
CHECK_EQ(out->ndim, 2);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
.set_body_typed(
tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm

#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
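
A minimal sketch, not part of this commit, of invoking the function registered above through the TVM registry. The shapes, device index, and helper name are illustrative assumptions; the argument order and ranks follow the CHECKs in tvm_cutlass_group_gemm_sm90, and the 4 MB workspace matches the recommendation in the comment.

#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

void example_fp16_group_gemm() {  // hypothetical helper, for illustration only
  using tvm::runtime::NDArray;
  DLDevice dev{kDLCUDA, 0};
  DLDataType f16{kDLFloat, 16, 1};
  DLDataType i64{kDLInt, 64, 1};
  DLDataType u8{kDLUInt, 8, 1};
  int64_t num_groups = 4, total_rows = 64, n = 128, k = 256;      // assumed sizes
  NDArray x = NDArray::Empty({total_rows, k}, f16, dev);          // 2-D input rows
  NDArray weight = NDArray::Empty({num_groups, n, k}, f16, dev);  // one [n, k] weight per group
  NDArray indptr = NDArray::Empty({num_groups}, i64, dev);        // cumulative rows per group
  NDArray workspace = NDArray::Empty({4 * 1024 * 1024}, u8, dev); // recommended 4 MB scratch
  NDArray out = NDArray::Empty({total_rows, n}, f16, dev);
  // ... fill x, weight, and indptr (indptr[num_groups - 1] must equal total_rows) ...
  const tvm::runtime::PackedFunc* f =
      tvm::runtime::Registry::Get("cutlass.group_gemm_fp16_sm90");
  ICHECK(f != nullptr);
  (*f)(x, weight, indptr, workspace, out);
}

The indptr layout (cumulative row counts rather than offsets with a leading zero) is inferred from the argument-preparation kernel in the fp8 file below.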
182 changes: 11 additions & 171 deletions src/runtime/contrib/cutlass/fp8_group_gemm.cu
@@ -23,181 +23,21 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include "group_gemm_runner.cuh"

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}

using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>; // <M,N,K> per group

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

inline size_t aligned(size_t value, size_t alignment = 16) {
return (value + alignment - 1) / alignment * alignment;
}

template <typename ElementA, typename ElementB, typename ElementC,
typename LayoutA = cutlass::layout::RowMajor,
typename LayoutB = cutlass::layout::ColumnMajor,
typename LayoutC = cutlass::layout::RowMajor>
struct CutlassFP8GroupGemmRunner {
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements
// (up to 16 bytes)

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag =
cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
template <>
struct KernelTraits<cutlass::float_e4m3_t> {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size
using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
EpilogueSchedule>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;

void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
ElementC** ptr_D,
typename ProblemShape::UnderlyingProblemShape* problem_sizes,
typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
uint8_t* workspace, int64_t workspace_size, int num_groups, float alpha,
float beta, cudaStream_t stream) {
typename Gemm::EpilogueOutputOp::Params epilogue_params{ElementAccumulator(alpha),
ElementAccumulator(beta)};

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
{num_groups, problem_sizes, problem_sizes_host},
{ptr_A, stride_A, ptr_B, stride_B},
{epilogue_params, ptr_C, stride_C, ptr_D, stride_D},
hw_info};
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run());
}
};

template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
typename StrideB, typename StrideC>
__global__ void prepare_group_gemm_arguments(
const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
int64_t* indptr, int64_t n, int64_t k, int64_t num_experts) {
int expert_id = threadIdx.x;
if (expert_id >= num_experts) return;
int prev_rows = expert_id == 0 ? 0 : indptr[expert_id - 1];
ptr_A[expert_id] = x + prev_rows * k;
ptr_B[expert_id] = weight + expert_id * k * n;
ptr_D[expert_id] = out + prev_rows * n;
problem_sizes[expert_id] = {static_cast<int>(indptr[expert_id] - prev_rows),
static_cast<int>(n), static_cast<int>(k)};
stride_A[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
stride_B[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
stride_D[expert_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
}
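
(So for three groups of 3, 5, and 2 rows, indptr would be [3, 8, 10]: group i covers rows [indptr[i-1], indptr[i]), with the i == 0 case starting at row 0. Note indptr holds cumulative counts, not offsets with a leading zero.)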

template <typename ElementA, typename ElementB, typename ElementC>
void cutlass_fp8_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
ElementC* out, cudaStream_t stream) {
using Runner = CutlassFP8GroupGemmRunner<ElementA, ElementB, ElementC>;
using StrideA = typename Runner::StrideA;
using StrideB = typename Runner::StrideB;
using StrideC = typename Runner::StrideC;

Runner runner;
std::ptrdiff_t offset = 0;
const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
offset += aligned(sizeof(ElementA*) * num_groups);
const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
offset += aligned(sizeof(ElementB*) * num_groups);
ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
offset += aligned(sizeof(ElementC*) * num_groups);
typename ProblemShape::UnderlyingProblemShape* problem_sizes =
reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
offset += aligned(sizeof(StrideA) * num_groups);
StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
offset += aligned(sizeof(StrideB) * num_groups);
StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
offset += aligned(sizeof(StrideC) * num_groups);
prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(
ptr_A, ptr_B, ptr_D, problem_sizes, stride_A, stride_B, stride_D, x, weight, out, indptr, n,
k, num_groups);
offset = aligned(offset, 256);
runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
workspace_size - offset, num_groups, 1.0f, 0.0f, stream);
}
template <>
struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {
};

namespace tvm {
namespace runtime {
@@ -218,10 +58,10 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
int n = weight->shape[1];
int k = weight->shape[2];
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_fp8_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
stream);
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
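The removed cutlass_fp8_group_gemm above also documents the workspace layout that the shared runner presumably inherits: 16-byte-aligned device arrays for the per-group A/B/D pointers, problem shapes, and strides, followed by a 256-byte-aligned region handed to CUTLASS as its internal workspace. A hedged sketch of the minimum argument-region size that layout implies (element sizes passed in as parameters, since the concrete stride and shape types live in the runner):

#include <cstddef>
#include <cstdint>

// Sketch only: mirrors the carving logic of the (now removed) cutlass_fp8_group_gemm.
// ptr_size, shape_size, and stride_size stand in for sizeof(ElementA*),
// sizeof(ProblemShape::UnderlyingProblemShape), and sizeof(StrideA), respectively.
inline size_t aligned(size_t value, size_t alignment = 16) {
  return (value + alignment - 1) / alignment * alignment;
}

size_t min_argument_bytes(int64_t num_groups, size_t ptr_size, size_t shape_size,
                          size_t stride_size) {
  size_t offset = 0;
  offset += 3 * aligned(ptr_size * num_groups);     // ptr_A, ptr_B, ptr_D arrays
  offset += aligned(shape_size * num_groups);       // per-group problem sizes
  offset += 3 * aligned(stride_size * num_groups);  // stride_A, stride_B, stride_D
  return aligned(offset, 256);  // CUTLASS's own workspace begins at this offset
}

For realistic group counts this comes to a few kilobytes, so the recommended 4 MB workspace is dominated by headroom for CUTLASS's internal allocation.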
