From 8b21235e9d4f30212d50ccbb1d60f6a7656542a6 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Mon, 12 Aug 2024 03:49:38 +0000
Subject: [PATCH] review comments

---
 csrc/cutlass_extensions/cute_utils.cuh        | 116 ++----------
 csrc/cutlass_extensions/torch_utils.hpp       |  91 +++++++++++---
 .../machete/machete_mm_kernel.cuh             | 105 +++++++++++-----
 .../machete/machete_mm_launcher.cuh           |  58 +++++----
 .../machete/machete_prepack_launcher.cuh      |  45 ++++---
 5 files changed, 223 insertions(+), 192 deletions(-)

diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh
index 14aa51703b6c5..83d58a9a867e1 100644
--- a/csrc/cutlass_extensions/cute_utils.cuh
+++ b/csrc/cutlass_extensions/cute_utils.cuh
@@ -1,112 +1,20 @@
 #pragma once
 
 #include <cute/tensor.hpp>
+#include <torch/all.h>
+namespace cute {
 
 ////////////////////////////////////////////////////////////////////
-// make_cute_stride
-// - instantiates a stride object thats correctly populated base
-//    on the shape of the tensor and the stride type passed in,
-//    for example:
-//  - if s = Stride<int, _1> and shape = {M, N, L} then the stride will be
-//    constructed as {N, 1}, i.e. Row Major
-//  - if s = Stride<_1, int> and shape = {M, N, L} then the stride will be
-//    constructed as {1, M}, i.e. Column Major
-//  - if s = Stride<int, _1, int64_t> and shape = {M, N, L} then the stride
-//    will be constructed as {N, 1, M * N}, i.e. Row Major Batched
-//  - etc.
+// layout utils
 ////////////////////////////////////////////////////////////////////
 
-//
-// Row Major Batched
-//
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>> make_cute_stride(
-    cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int, int, int> shape_MNL) {
-  static_assert(std::is_integral_v<IntT>,
-                "Stride must have an integral type so it can be set "
-                "dynamically. Static strides not supported.");
-  auto s_copy = s;
-  cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MNL));
-  return s_copy;
-}
-
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>> make_cute_stride(
-    cute::Stride<IntT, cute::Int<1>> s, int M, int N, int L) {
-  return make_cute_stride(s, cute::make_shape(M, N, L));
-}
-
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT> make_cute_stride(
-    cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int, int, int> shape_MNL) {
-  static_assert(std::is_integral_v<IntT>,
-                "Stride must have an integral type so it can be set "
-                "dynamically. Static strides not supported.");
-  auto s_copy = s;
-  cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MNL));
-  return s_copy;
-}
-
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT> make_cute_stride(
-    cute::Stride<cute::Int<1>, IntT> s, int M, int N, int L) {
-  return make_cute_stride(s, cute::make_shape(M, N, L));
-}
-
-//
-// Row Major Batched
-//
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>, int64_t> make_cute_stride(
-    cute::Stride<IntT, cute::Int<1>, int64_t> s,
-    cute::Shape<int, int, int> shape_MNL) {
-  static_assert(std::is_integral_v<IntT>,
-                "Stride must have an integral type so it can be set "
-                "dynamically. Static strides not supported.");
-  auto s_copy = s;
-  cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MNL));
-  int batch_count = cute::get<2>(shape_MNL);
-  if (batch_count > 1) {
-    cute::get<2>(s_copy) =
-        static_cast<int64_t>(cute::get<0>(shape_MNL) * cute::get<1>(shape_MNL));
-  } else {
-    cute::get<2>(s_copy) = static_cast<int64_t>(0);
-  }
-  return s_copy;
-}
-
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>, int64_t> make_cute_stride(
-    cute::Stride<IntT, cute::Int<1>, int64_t> s, int M, int N, int L) {
-  return make_cute_stride(s, cute::make_shape(M, N, L));
-}
-
-//
-// Col Major Batched
-//
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT, int64_t> make_cute_stride(
-    cute::Stride<cute::Int<1>, IntT, int64_t> s,
-    cute::Shape<int, int, int> shape_MNL) {
-  static_assert(std::is_integral_v<IntT>,
-                "Stride must have an integral type so it can be set "
-                "dynamically. Static strides not supported.");
-  auto s_copy = s;
-  cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MNL));
-  int batch_count = cute::get<2>(shape_MNL);
-  if (batch_count > 1) {
-    cute::get<2>(s_copy) =
-        static_cast<int64_t>(cute::get<0>(shape_MNL) * cute::get<1>(shape_MNL));
-  } else {
-    cute::get<2>(s_copy) = static_cast<int64_t>(0);
-  }
-  return s_copy;
-}
-
-template <typename IntT>
-CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT, int64_t> make_cute_stride(
-    cute::Stride<cute::Int<1>, IntT, int64_t> s, int M, int N, int L) {
-  return make_cute_stride(s, cute::make_shape(M, N, L));
+// Permute layout based on indices, example:
+//   permute_layout<1, 0>(layout) will swap the two dimensions
+//   permute_layout<0, 2, 1>(layout) will swap the last two dimensions
+template <size_t... I, typename Layout>
+auto permute_layout(Layout l) {
+  static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
+  return cute::make_layout(cute::get<I>(l)...);
 }
 
 ////////////////////////////////////////////////////////////////////
@@ -120,4 +28,6 @@ static constexpr auto get_logical_ptr(PointerType* ptr) {
   } else {
     return ptr;
   }
-}
\ No newline at end of file
+}
+
+};  // namespace cute
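To make the new permute_layout helper concrete, here is a minimal standalone sketch (illustrative only, not part of the patch; the shapes and strides are made up):

    #include "cutlass_extensions/cute_utils.cuh"

    // permute_layout reorders the modes of a CuTe layout by index.
    void permute_layout_example() {
      using namespace cute;
      // (M, N, L) = (16, 8, 1), row-major batched: strides (8, 1, 128)
      auto l = make_layout(make_shape(16, 8, 1), make_stride(8, 1, 128));
      // Swapping the first two modes gives (8, 16, 1):(1, 8, 128), i.e. the
      // transposed (column-major) view of the same underlying data.
      auto lt = permute_layout<1, 0, 2>(l);
      print(lt);  // prints (8,16,1):(1,8,128)
    }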
Static strides not supported."); - auto s_copy = s; - cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MNL)); - int batch_count = cute::get<2>(shape_MNL); - if (batch_count > 1) { - cute::get<2>(s_copy) = - static_cast(cute::get<0>(shape_MNL) * cute::get<1>(shape_MNL)); - } else { - cute::get<2>(s_copy) = static_cast(0); - } - return s_copy; -} - -template -CUTLASS_HOST_DEVICE cute::Stride, int64_t> make_cute_stride( - cute::Stride, int64_t> s, int M, int N, int L) { - return make_cute_stride(s, cute::make_shape(M, N, L)); -} - -// -// Col Major Batched -// -template -CUTLASS_HOST_DEVICE cute::Stride, IntT, int64_t> make_cute_stride( - cute::Stride, IntT, int64_t> s, - cute::Shape shape_MNL) { - static_assert(std::is_integral_v, - "Stride must have an integral type so it can be set " - "dynamically. Static strides not supported."); - auto s_copy = s; - cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MNL)); - int batch_count = cute::get<2>(shape_MNL); - if (batch_count > 1) { - cute::get<2>(s_copy) = - static_cast(cute::get<0>(shape_MNL) * cute::get<1>(shape_MNL)); - } else { - cute::get<2>(s_copy) = static_cast(0); - } - return s_copy; -} - -template -CUTLASS_HOST_DEVICE cute::Stride, IntT, int64_t> make_cute_stride( - cute::Stride, IntT, int64_t> s, int M, int N, int L) { - return make_cute_stride(s, cute::make_shape(M, N, L)); +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); } //////////////////////////////////////////////////////////////////// @@ -120,4 +28,6 @@ static constexpr auto get_logical_ptr(PointerType* ptr) { } else { return ptr; } -} \ No newline at end of file +} + +}; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index a70a2f201f361..84b7c54a430a5 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -2,6 +2,7 @@ #include +#include "cute/layout.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/bfloat16.h" #include "cutlass/half.h" @@ -9,34 +10,86 @@ using ColumnMajor = typename cutlass::layout::ColumnMajor; using RowMajor = typename cutlass::layout::RowMajor; -static inline bool is_row_major(torch::Tensor const tensor) { - TORCH_CHECK(tensor.dim() == 2); - return tensor.is_contiguous(); +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(get(static_cast(t)), I)...); } -static inline bool is_column_major(torch::Tensor const tensor) { - TORCH_CHECK(tensor.dim() == 2); - return tensor.stride(0) == 1 && tensor.stride(1) == tensor.size(0); +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); } -template -T* data_ptr(torch::Tensor const tensor, char const* name) { - if constexpr (std::is_same_v) { - TORCH_CHECK(is_row_major(tensor), "Expected ", name, " to be RowMajor"); - } else if constexpr (std::is_same_v) { - TORCH_CHECK(is_column_major(tensor), "Expected ", name, - " to be ColumnMajor"); +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... 
diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh
index 6d8c734a86f64..36fcfefa748de 100644
--- a/csrc/quantization/machete/machete_mm_kernel.cuh
+++ b/csrc/quantization/machete/machete_mm_kernel.cuh
@@ -45,8 +45,9 @@ struct MacheteKernelTemplate {
   using ElementB = ElementB_;
   using ElementD = ElementD_;
   using ElementC = cute::conditional_t<with_C, ElementD, void>;
-  using ElementZero = ZeroT;
-  using ElementScale = ScaleT;
+  using ElementZ = ZeroT;
+  using ElementS = ScaleT;
+  using ElementAccumulator = AccumulatorT;  // Element type for internal accumulation
 
   using ElementCompute = AccumulatorT;  // For Epilogue
 
@@ -54,8 +55,8 @@ struct MacheteKernelTemplate {
   using BTypeTuple = cute::conditional_t<
       with_scales,
-      cute::conditional_t<with_zeropoints,
-                          cute::tuple<ElementB, ElementScale, ElementZero>,
-                          cute::tuple<ElementB, ElementScale>>,
+      cute::conditional_t<with_zeropoints,
+                          cute::tuple<ElementB, ElementS, ElementZ>,
+                          cute::tuple<ElementB, ElementS>>,
       ElementB>;
 
   using LayoutA = cutlass::layout::RowMajor;
@@ -65,6 +66,13 @@ struct MacheteKernelTemplate {
   // not actually used since B has the prepacked layout, but required by cutlass
   using _LayoutB = cutlass::layout::ColumnMajor;
 
+  // Interface strides expected by create_arguments (will get transposed)
+  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
+  using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
+  using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
+  using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
+  using StrideZ = StrideS;
+
   using LayoutA_Transpose =
       typename cutlass::layout::LayoutTranspose<LayoutA>::type;
   using LayoutC_Transpose =
@@ -115,11 +123,6 @@ struct MacheteKernelTemplate {
       CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
-  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
-  using StrideC = typename GemmKernel::StrideC;
-  using StrideD = typename GemmKernel::StrideD;
-  using StrideS = typename CollectiveMainloop::StrideScale;
-
   // stride_B is unused (since B is prepacked), but still required by cutlass
   using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
 
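An aside for orientation (not patch content): with CUTLASS 3.x stride tags, the interface stride aliases above all expand to rank-3 (mode0, mode1, batch) tuples, so a check along these lines would hold:

    // Illustrative static_assert documenting the assumed TagToStrideA_t
    // mapping; requires <type_traits>.
    static_assert(std::is_same_v<
        cutlass::detail::TagToStrideA_t<cutlass::layout::RowMajor>,
        cute::Stride<int64_t, cute::Int<1>, int64_t>>);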
@@ -127,41 +130,81 @@ struct MacheteKernelTemplate {
   using MainloopArguments = typename GemmKernel::MainloopArguments;
   using EpilogueArguments = typename GemmKernel::EpilogueArguments;
 
-  static Arguments create_arguments(cudaStream_t stream, int M, int N, int K,
-                                    ElementA const* A, ElementB const* B,
-                                    ElementC const* C, ElementD* D,
-                                    ElementScale const* scales,
-                                    ElementZero const* zeros,
-                                    ElementCompute alpha, ElementCompute beta,
-                                    std::optional<int> maybe_group_size) {
+  template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
+            typename ShapeZ>
+  static Arguments create_arguments(
+      cudaStream_t stream,
+      ElementA const* A_ptr,  // A is an MxK matrix
+      Layout<ShapeA, StrideA> const& layout_A,
+      ElementB const* B_ptr,  // B is a KxN prepacked matrix
+      ElementD* D_ptr,        // D is an MxN matrix
+      Layout<ShapeD, StrideD> const& layout_D,
+      ElementC const* C_ptr,  // C is an MxN matrix
+      std::optional<Layout<ShapeC, StrideC>> const& layout_C,
+      ElementS const* S_ptr,  // S is a scale_KxN matrix
+      std::optional<Layout<ShapeS, StrideS>> const& layout_S,
+      ElementZ const* Z_ptr,  // Z is a scale_KxN matrix
+      std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
+      ElementCompute alpha, ElementCompute beta,
+      std::optional<int> maybe_group_size) {
     static_assert(!with_zeropoints || with_scales);
 
-    TORCH_CHECK(with_C || (!with_C && beta == 0));
-    TORCH_CHECK(with_scales || !scales);
-    TORCH_CHECK(with_zeropoints || !zeros);
+    int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
 
-    static int constexpr L = 1;
     int const group_size = maybe_group_size.value_or(K);
     int const scale_k = (K + group_size - 1) / group_size;
 
-    // stride_B is unused (since B is prepacked), but still required by cutlass
-    auto stride_A = make_cute_stride(StrideA{}, N, K, L);
-    auto stride_B = make_cute_stride(_StrideB{}, M, K, L);
-    auto stride_C = make_cute_stride(StrideC{}, N, M, L);
-    auto stride_D = make_cute_stride(StrideD{}, N, M, L);
-    auto stride_S = make_cute_stride(StrideS{}, N, scale_k, L);
+    TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
+    TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
+
+    if constexpr (with_C) {
+      TORCH_CHECK(C_ptr && layout_C);
+    } else {
+      TORCH_CHECK(!C_ptr, "C not supported");
+    }
+
+    if constexpr (with_scales) {
+      TORCH_CHECK(S_ptr && layout_S);
+      TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
+    } else {
+      TORCH_CHECK(!S_ptr, "Scales not supported");
+    }
+
+    if constexpr (with_zeropoints) {
+      TORCH_CHECK(Z_ptr && layout_Z);
+      TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
+      TORCH_CHECK(layout_S && *layout_Z == *layout_S,
+                  "Scales and zeros must have the same layout");
+    } else {
+      TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
+    }
+
+    // Transpose A and D
+    // A doesn't need to be transposed since cutlass expects an NxK matrix
+    // for B (which is At)
+    auto stride_At = layout_A.stride();
+    auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
+    auto stride_Ct = stride_Dt;
+    if (layout_C) {
+      stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
+    }
 
     MainloopArguments mainloop_arguments{};
     EpilogueArguments epilogue_arguments{
-        {alpha, beta}, C, stride_C, D, stride_D};
+        {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
 
     if constexpr (with_scales && with_zeropoints) {
-      mainloop_arguments = MainloopArguments{
-          B, stride_B, A, stride_A, scales, stride_S, group_size, zeros};
+      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
+      mainloop_arguments =
+          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
+                            S_ptr, stride_S, group_size, Z_ptr};
     } else if constexpr (with_scales) {
+      auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
       mainloop_arguments = MainloopArguments{
-          B, stride_B, A, stride_A, scales, stride_S, group_size};
+          B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
     } else {
-      mainloop_arguments = MainloopArguments{B, stride_B, A, stride_A};
+      mainloop_arguments =
+          MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
     }
 
     return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
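The stride bookkeeping in create_arguments follows from computing the output through its transpose, Dt = Bt x At: the prepacked quantized matrix takes cutlass's "A" operand slot and the activations take the "B" slot. A worked example of the comment above (values only, a sketch rather than a test):

    // With M = 256, N = 512, L = 1 and a row-major D:
    //   layout_D  = (256, 512, 1):(512, 1, 131072)
    //   stride_Dt = permute_layout<1, 0, 2>(layout_D).stride()
    //             = (1, 512, 131072)      // column-major strides for Dt
    // layout_A = (256, K, 1):(K, 1, 256*K) is used unpermuted: read as the
    // (rows-of-At, K)-indexed second operand cutlass expects, A's row-major
    // strides already fit, which is why stride_At is just layout_A.stride().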
diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh
index 0ad9af656d05c..e5aa55baafb20 100644
--- a/csrc/quantization/machete/machete_mm_launcher.cuh
+++ b/csrc/quantization/machete/machete_mm_launcher.cuh
@@ -9,8 +9,8 @@
 namespace machete {
 
 struct PyTorchArguments {
-  torch::Tensor const A;
-  torch::Tensor const B;
+  torch::Tensor const& A;
+  torch::Tensor const& B;
   c10::optional<torch::Tensor> const& scales;
   c10::optional<torch::Tensor> const& zeros;
   c10::optional<int64_t> group_size;
@@ -27,18 +27,18 @@ torch::Tensor run_impl(PyTorchArguments args) {
   auto device = args.A.device();
   auto stream = at::cuda::getCurrentCUDAStream(device.index());
 
-  using ElementA = typename MacheteKernel::ElementA;
-  using ElementB = typename MacheteKernel::ElementB;
-  using ElementC = typename MacheteKernel::ElementC;
-  using ElementD = typename MacheteKernel::ElementD;
-  using ElementScale = typename MacheteKernel::ElementScale;
-  using ElementZero = typename MacheteKernel::ElementZero;
+  using EleA = typename MacheteKernel::ElementA;
+  using EleB = typename MacheteKernel::ElementB;
+  using EleC = typename MacheteKernel::ElementC;
+  using EleD = typename MacheteKernel::ElementD;
+  using EleS = typename MacheteKernel::ElementS;
+  using EleZ = typename MacheteKernel::ElementZ;
 
-  using LayoutA = typename MacheteKernel::LayoutA;
-  using LayoutC = typename MacheteKernel::LayoutC;
-  using LayoutD = typename MacheteKernel::LayoutD;
-  using LayoutScale = typename MacheteKernel::LayoutScale;
-  using LayoutZero = typename MacheteKernel::LayoutScale;
+  using StrideA = typename MacheteKernel::StrideA;
+  using StrideC = typename MacheteKernel::StrideC;
+  using StrideD = typename MacheteKernel::StrideD;
+  using StrideS = typename MacheteKernel::StrideS;
+  using StrideZ = typename MacheteKernel::StrideZ;
 
   int M = args.A.size(0);
   int N = args.B.size(1);
@@ -47,23 +47,31 @@ torch::Tensor run_impl(PyTorchArguments args) {
   // Allocate output
   torch::Tensor D = torch::empty({M, N}, torch::TensorOptions()
-                                             .dtype(equivalent_scalar_type_v<ElementD>)
+                                             .dtype(equivalent_scalar_type_v<EleD>)
                                              .device(device));
 
-  auto A_ptr = data_ptr<ElementA const, LayoutA>(args.A, "A");
-  auto B_ptr = data_ptr<ElementB const>(args.B, "B");
-  auto D_ptr = data_ptr<ElementD, LayoutD>(D, "D");
-  auto C_ptr = maybe_data_ptr<ElementC const, LayoutC>(args.C, "C");
-  auto scales_ptr =
-      maybe_data_ptr<ElementScale const, LayoutScale>(args.scales, "scales");
-  auto zeros_ptr =
-      maybe_data_ptr<ElementZero const, LayoutZero>(args.zeros, "zeros");
+  auto const &A = args.A, &B = args.B;
+  auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
+
+  auto layout_A = make_cute_layout<StrideA>(A, "A");
+  auto layout_D = make_cute_layout<StrideD>(D, "D");
+  auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
+  auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
+  auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
+
+  auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
+  auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
+  auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
+  auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
+  auto S_ptr =
+      static_cast<EleS const*>(scales ? scales->const_data_ptr() : nullptr);
+  auto Z_ptr =
+      static_cast<EleZ const*>(zeros ? zeros->const_data_ptr() : nullptr);
 
   auto arguments = MacheteKernel::create_arguments(
-      stream, M, N, K, A_ptr, B_ptr, C_ptr, D_ptr, scales_ptr, zeros_ptr,
-      args.alpha.value_or(1), args.beta.value_or(0),
+      stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
+      layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
       args.group_size.value_or(K));
-
   TORCH_CHECK(MacheteKernel::can_implement(arguments),
               "Machete kernel cannot be run with these arguments");
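One behavioral note on the launcher change (an observation, not patch content): make_cute_layout only verifies the statically-known stride elements, while the old data_ptr path required full contiguity. For example:

    // sizes (128, 32), strides (64, 1): not contiguous, so the old
    // data_ptr<..., RowMajor>/is_row_major() check would have rejected it;
    // make_cute_layout<StrideA> only checks the static Int<1> against
    // A.stride(1) == 1 and copies the dynamic row stride of 64, so
    // column-sliced (row-strided) inputs now pass through.
    torch::Tensor A = torch::empty({128, 64}).slice(1, 0, 32);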
diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh
index 9f2cbb1bc495b..0531b02fef6a7 100644
--- a/csrc/quantization/machete/machete_prepack_launcher.cuh
+++ b/csrc/quantization/machete/machete_prepack_launcher.cuh
@@ -8,30 +8,47 @@ namespace machete {
 
 template <typename PrepackedLayoutB>
 torch::Tensor prepack_impl(torch::Tensor const B) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
+  using ElementB = typename PrepackedLayoutB::ElementB;
   auto device = B.device();
   auto stream = at::cuda::getCurrentCUDAStream(device.index());
-
-  using ElementB = typename PrepackedLayoutB::ElementB;
-  using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
-
-  auto B_ptr = data_ptr<ElementB const>(B, "B");
-
-  auto elements_per_storage_item =
+  auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
+  // elements per storage item for B
+  auto eles_per_storage =
       (B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
 
-  int N = B.size(0) * elements_per_storage_item;
-  int M = B.size(1);
+  // torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
+  // match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
+  auto Bt_packed = B.t();
 
-  auto const shape_Bt = cute::make_shape(M, N, 1);
-  auto const stride_B = make_cute_stride(StrideB{}, shape_Bt);
+  using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
+  auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
+
+  // convert (N,packed_K,L) layout to (N,K,L) layout
+  // in effect we want to do: blocked_product(layout_Bt_packed,
+  //                 make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
+  //                                     Step<_1, _0, _2>{}));
+  // but blocked_product does not support dynamic strides so we implement the
+  // equivalent manually,
+  //   new_shape  = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
+  //   new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
+  //                when s1 == 1
+  TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
+  // clang-format off
+  auto const layout_Bt = make_layout(
+      transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
+        return idx == 1 ? ele * eles_per_storage : ele;
+      }),
+      transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
+        return idx != 1 ? ele * eles_per_storage : ele;
+      }));
+  // clang-format on
 
   // Allocate output
   torch::Tensor D = torch::empty_like(B);
 
-  prepack_B(
-      stream, B_ptr, make_layout(shape_Bt, stride_B),
-      reinterpret_cast<ElementB*>(D.mutable_data_ptr()));
+  prepack_B(stream, B_ptr, layout_Bt,
+            static_cast<ElementB*>(D.mutable_data_ptr()));
 
   return D;
 };
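To make the prepack layout expansion concrete, a worked example under assumed dtypes (illustrative, not part of the patch): with 4-bit ElementB packed into int32 storage,

    // eles_per_storage = 32 / 4 = 8, and with
    //   l_Bt_packed = (N, packed_K, L):(s0, 1, s2)
    // the transform above produces
    //   shape : (N, packed_K * 8, L) = (N, K, L)   // only idx == 1 is scaled
    //   stride: (s0 * 8, 1, s2 * 8)                // idx != 1 is scaled
    // i.e. the K mode is re-expressed in sub-byte elements while the N and
    // batch strides are rescaled from storage words to elements.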