diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake
index 8718885c82221..0b9c76b327c78 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -53,6 +53,12 @@ if(USE_CUDA AND USE_CUTLASS)
   ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
   set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
   set(TVM_CUTLASS_RUNTIME_SRCS "")
+
+  # TODO: Should get rid of the postfix 'a' and test sm >= 90
+  if (CMAKE_CUDA_ARCHITECTURES MATCHES "90|90a")
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
+  endif()
+
   if (USE_CUDA_FP8)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_fp8_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
@@ -69,4 +75,4 @@ if(USE_CUDA AND USE_CUTLASS)
   list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")
 
   message(STATUS "Build with CUTLASS")
-endif()
\ No newline at end of file
+endif()
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu b/src/runtime/contrib/cutlass/fp16_group_gemm.cu
new file mode 100644
index 0000000000000..aac7472f72611
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cu
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "group_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+template <>
+struct KernelTraits<cutlass::half_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
+};
+
+namespace tvm {
+namespace runtime {
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
+                                 NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
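+  // Note: `indptr` holds the running total of rows per group, so group i reads rows
+  // [indptr[i-1], indptr[i]) of `x` (with indptr[-1] taken as 0); see prepare_group_gemm_arguments
+  // in group_gemm_runner.cuh.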
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
+                     stream);
+}
+
+TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
+    .set_body_typed(
+        tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
\ No newline at end of file
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm.cu
index cc732a97bd0f5..439d3d1361d22 100644
--- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu
@@ -23,181 +23,21 @@
 #include <cuda_fp16.h>
 #include <float.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
+#include "group_gemm_runner.cuh"
 
-#include "../../cuda/cuda_common.h"
-
-// clang-format off
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cutlass/tensor_ref.h"
-#include "cutlass/epilogue/collective/default_epilogue.hpp"
-#include "cutlass/epilogue/thread/linear_combination.h"
-#include "cutlass/gemm/dispatch_policy.hpp"
-#include "cutlass/gemm/group_array_problem_shape.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-// clang-format on
-
-#define CUTLASS_CHECK(status)                                                                    \
-  {                                                                                              \
-    cutlass::Status error = status;                                                              \
-    if (error != cutlass::Status::kSuccess) {                                                    \
-      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
-                << std::endl;                                                                    \
-      exit(EXIT_FAILURE);                                                                        \
-    }                                                                                            \
-  }
-
-using namespace cute;
-using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
 
 #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
 
-inline size_t aligned(size_t value, size_t alignment = 16) {
-  return (value + alignment - 1) / alignment * alignment;
-}
-
-template <typename ElementA, typename ElementB, typename ElementC,
-          typename LayoutA = cutlass::layout::RowMajor,
-          typename LayoutB = cutlass::layout::ColumnMajor,
-          typename LayoutC = cutlass::layout::RowMajor>
-struct CutlassFP8GroupGemmRunner {
-  static constexpr int AlignmentA =
-      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements
-                                                    // (up to 16 bytes)
-
-  static constexpr int AlignmentB =
-      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements
-                                                    // (up to 16 bytes)
-
-  static constexpr int AlignmentC =
-      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements
-                                                    // (up to 16 bytes)
-
-  // Core kernel configurations
-  using ElementAccumulator = float;  // Element type for internal accumulation
-  using ArchTag =
-      cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
-  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+template <>
+struct KernelTraits<cutlass::float_e4m3_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
   using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
   using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
-  using StageCountType =
-      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
-  using KernelSchedule =
-      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;  // Kernel to launch
-  using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  // Epilogue to launch
-
-  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
-      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
-      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
-      ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
-      EpilogueSchedule>::CollectiveOp;
-
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
-      ElementAccumulator, TileShape, ClusterShape,
-      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
-          sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      KernelSchedule>::CollectiveOp;
-
-  using GemmKernel =
-      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
-
-  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-
-  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
-  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
-  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
-  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
-
-  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
-                      ElementC** ptr_D,
-                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
-                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
-                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
-                      uint8_t* workspace, int64_t workspace_size, int num_groups, float alpha,
-                      float beta, cudaStream_t stream) {
-    typename Gemm::EpilogueOutputOp::Params epilogue_params{ElementAccumulator(alpha),
-                                                            ElementAccumulator(beta)};
-
-    cutlass::KernelHardwareInfo hw_info;
-    hw_info.device_id = 0;
-    hw_info.sm_count =
-        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
-                                       {num_groups, problem_sizes, problem_sizes_host},
-                                       {ptr_A, stride_A, ptr_B, stride_B},
-                                       {epilogue_params, ptr_C, stride_C, ptr_D, stride_D},
-                                       hw_info};
-    Gemm gemm_op;
-    CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
-    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
-    CUTLASS_CHECK(gemm_op.run());
-  }
 };
 
-template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
-          typename StrideB, typename StrideC>
-__global__ void prepare_group_gemm_arguments(
-    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
-    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
-    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
-    int64_t* indptr, int64_t n, int64_t k, int64_t num_experts) {
-  int expert_id = threadIdx.x;
-  if (expert_id >= num_experts) return;
-  int prev_rows = expert_id == 0 ? 0 : indptr[expert_id - 1];
-  ptr_A[expert_id] = x + prev_rows * k;
-  ptr_B[expert_id] = weight + expert_id * k * n;
-  ptr_D[expert_id] = out + prev_rows * n;
-  problem_sizes[expert_id] = {static_cast<int>(indptr[expert_id] - prev_rows),
-                              static_cast<int>(n), static_cast<int>(k)};
-  stride_A[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_B[expert_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_D[expert_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
-}
-
-template <typename ElementA, typename ElementB, typename ElementC>
-void cutlass_fp8_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
-                            int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
-                            ElementC* out, cudaStream_t stream) {
-  using Runner = CutlassFP8GroupGemmRunner<ElementA, ElementB, ElementC>;
-  using StrideA = typename Runner::StrideA;
-  using StrideB = typename Runner::StrideB;
-  using StrideC = typename Runner::StrideC;
-
-  Runner runner;
-  std::ptrdiff_t offset = 0;
-  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
-  offset += aligned(sizeof(ElementA*) * num_groups);
-  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
-  offset += aligned(sizeof(ElementB*) * num_groups);
-  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
-  offset += aligned(sizeof(ElementC*) * num_groups);
-  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
-      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
-  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
-  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
-  offset += aligned(sizeof(StrideA) * num_groups);
-  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
-  offset += aligned(sizeof(StrideB) * num_groups);
-  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
-  offset += aligned(sizeof(StrideC) * num_groups);
-  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(
-      ptr_A, ptr_B, ptr_D, problem_sizes, stride_A, stride_B, stride_D, x, weight, out, indptr, n,
-      k, num_groups);
-  offset = aligned(offset, 256);
-  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
-                        nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
-                        workspace_size - offset, num_groups, 1.0f, 0.0f, stream);
-}
+template <>
+struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {
+};
 
 namespace tvm {
 namespace runtime {
@@ -218,10 +58,10 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
   int n = weight->shape[1];
   int k = weight->shape[2];
   cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
-  cutlass_fp8_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
-                         static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
-                         workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
-                         stream);
+  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+                     workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
+                     stream);
 }
 
 TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
new file mode 100644
index 0000000000000..f149a957a9b76
--- /dev/null
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status)                                                                    \
+  {                                                                                              \
+    cutlass::Status error = status;                                                              \
+    if (error != cutlass::Status::kSuccess) {                                                    \
+      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
+                << std::endl;                                                                    \
+      exit(EXIT_FAILURE);                                                                        \
+    }                                                                                            \
+  }
+
+using namespace cute;
+using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
+
+inline size_t aligned(size_t value, size_t alignment = 16) {
+  return (value + alignment - 1) / alignment * alignment;
+}
+
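+// Each element type supplies a KernelTraits specialization that selects the kernel schedule,
+// threadblock tile shape, and cluster shape (see fp16_group_gemm.cu and fp8_group_gemm.cu).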
+template <typename ElementA>
+struct KernelTraits;
+
+template <typename ElementA, typename ElementB, typename ElementC,
+          typename LayoutA = cutlass::layout::RowMajor,
+          typename LayoutB = cutlass::layout::ColumnMajor,
+          typename LayoutC = cutlass::layout::RowMajor>
+struct CutlassGroupGemmRunner {
+  static constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentC =
+      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements
+                                                    // (up to 16 bytes)
+
+  // Core kernel configurations
+  using ElementAccumulator = float;  // Element type for internal accumulation
+  using ArchTag =
+      cutlass::arch::Sm90;  // Tag indicating the minimum SM that supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+  using TileShape = typename KernelTraits<ElementA>::TileShape;
+  using ClusterShape = typename KernelTraits<ElementA>::ClusterShape;
+  using StageCountType =
+      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
+  using KernelSchedule = typename KernelTraits<ElementA>::KernelSchedule;     // Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;  // Epilogue to launch
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
+      EpilogueSchedule>::CollectiveOp;
+
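+  // The builders receive pointer-to-layout types (LayoutA*, LayoutB*, LayoutC*), which selects
+  // CUTLASS's grouped (pointer-array) collectives that read strides per group at runtime.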
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
+  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
+  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
+  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
+
+  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
+                      ElementC** ptr_D,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
+                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
+                      uint8_t* workspace, int64_t workspace_size, int num_groups, float alpha,
+                      float beta, cudaStream_t stream) {
+    typename Gemm::EpilogueOutputOp::Params epilogue_params{ElementAccumulator(alpha),
+                                                            ElementAccumulator(beta)};
+
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
+                                       {num_groups, problem_sizes, problem_sizes_host},
+                                       {ptr_A, stride_A, ptr_B, stride_B},
+                                       {epilogue_params, ptr_C, stride_C, ptr_D, stride_D},
+                                       hw_info};
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run());
+  }
+};
+
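+// Fills the per-group pointer, problem-size, and stride arrays consumed by CUTLASS. It is
+// launched below with a single block and one thread per group, so num_groups is assumed to fit
+// within one block.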
+template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
+          typename StrideB, typename StrideC>
+__global__ void prepare_group_gemm_arguments(
+    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
+    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
+    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
+    int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) {
+  int group_id = threadIdx.x;
+  if (group_id >= num_groups) return;
+  int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
+  ptr_A[group_id] = x + prev_rows * k;
+  ptr_B[group_id] = weight + group_id * k * n;
+  ptr_D[group_id] = out + prev_rows * n;
+  problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows),
+                             static_cast<int>(n), static_cast<int>(k)};
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
+}
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
+                        int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
+                        ElementC* out, cudaStream_t stream) {
+  using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
+  using StrideA = typename Runner::StrideA;
+  using StrideB = typename Runner::StrideB;
+  using StrideC = typename Runner::StrideC;
+
+  Runner runner;
+  std::ptrdiff_t offset = 0;
+  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
+  offset += aligned(sizeof(ElementA*) * num_groups);
+  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
+  offset += aligned(sizeof(ElementB*) * num_groups);
+  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
+  offset += aligned(sizeof(ElementC*) * num_groups);
+  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
+      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
+  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
+  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
+  offset += aligned(sizeof(StrideA) * num_groups);
+  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
+  offset += aligned(sizeof(StrideB) * num_groups);
+  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
+  offset += aligned(sizeof(StrideC) * num_groups);
+  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(
+      ptr_A, ptr_B, ptr_D, problem_sizes, stride_A, stride_B, stride_D, x, weight, out, indptr, n,
+      k, num_groups);
+  offset = aligned(offset, 256);
+  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
+                        nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
+                        workspace_size - offset, num_groups, 1.0f, 0.0f, stream);
+}
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 11437f7d682ab..949b6fb71df00 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -2504,5 +2504,92 @@ def main(
     _test_batched_var_len_attention(Module, seq_lens, num_head, num_kv_head, head_size, window_size)
 
 
+def test_grouped_gemm_lora():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            x: R.Tensor(("batch_size", 128), dtype="float16"),
+            matmul_weight_base: R.Tensor((128, 256), dtype="float16"),
+            indexed_matmul_weight_LA: R.Tensor(
+                ("routing_table_size", "matmul_weight_lora_r", 128), dtype="float16"
+            ),
+            indexed_matmul_weight_LB: R.Tensor(
+                ("routing_table_size", 256, "matmul_weight_lora_r"), dtype="float16"
+            ),
+            total_rows_before: R.Tensor(("routing_table_size",), dtype="int64"),
+        ) -> R.Tensor(("batch_size", 256), dtype="float16"):
+            batch_size = T.int64()
+            matmul_weight_lora_r = T.int64()
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                base = R.matmul(x, matmul_weight_base, out_dtype="float16")
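+                # Scratch buffer shared by the two grouped GEMM calls below; it holds the
+                # per-group pointers/strides plus CUTLASS's internal workspace (the kernel
+                # recommends about 4MB).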
+                workspace = R.zeros(R.shape([4194304]), dtype="float16")
+                lora_A_out = relax.call_dps_packed(
+                    "cutlass.group_gemm_fp16_sm90",
+                    [
+                        x,
+                        indexed_matmul_weight_LA,
+                        total_rows_before,
+                        workspace,
+                    ],
+                    out_sinfo=relax.TensorStructInfo(
+                        (batch_size, matmul_weight_lora_r),
+                        x.struct_info.dtype,
+                    ),
+                )
+                lora_B_out = relax.call_dps_packed(
+                    "cutlass.group_gemm_fp16_sm90",
+                    [
+                        lora_A_out,
+                        indexed_matmul_weight_LB,
+                        total_rows_before,
+                        workspace,
+                    ],
+                    out_sinfo=relax.TensorStructInfo(
+                        (batch_size, 256),
+                        x.struct_info.dtype,
+                    ),
+                )
+                gv: R.Tensor((batch_size, 256), dtype="float16") = R.add(base, lora_B_out)
+                R.output(gv)
+            return gv
+
+    batch_size = 16
+    rank = 64
+    num_lora = 3
+
+    inp = np.random.randn(batch_size, 128).astype("float16")
+    base_weight = np.random.randn(128, 256).astype("float16")
+    lora_A = np.random.randn(num_lora, 128, rank).astype("float16")
+    lora_B = np.random.randn(num_lora, rank, 256).astype("float16")
+
+    out_np = np.dot(inp.astype("float32"), base_weight.astype("float32"))
+
+    offsets = [0, 6, 12, 16]
+    for i in range(num_lora):
+        x = inp[offsets[i] : offsets[i + 1]].astype("float32")
+        lora_out = np.dot(np.dot(x, lora_A[i].astype("float32")), lora_B[i].astype("float32"))
+        out_np[offsets[i] : offsets[i + 1]] += lora_out
+
+    out_np = out_np.astype("float16")
+
+    with tvm.target.Target("cuda"):
+        mod = relax.transform.LegalizeOps()(Module)
+        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
+    lora_A = np.transpose(lora_A, [0, 2, 1])
+    lora_B = np.transpose(lora_B, [0, 2, 1])
+    total_rows_before = np.array(offsets[1:], dtype="int64")
+
+    out = build_and_run(
+        mod,
+        [inp, base_weight, lora_A, lora_B, total_rows_before],
+        "cuda",
+    )
+
+    assert np.mean(np.abs(out_np - out)) < 5e-2
+
+
 if __name__ == "__main__":
     tvm.testing.main()