Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Aug 12, 2024
1 parent d5b896e commit 8b21235
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 192 deletions.
116 changes: 13 additions & 103 deletions csrc/cutlass_extensions/cute_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,112 +1,20 @@
#pragma once

#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {

////////////////////////////////////////////////////////////////////
// make_cute_stride
// - instantiates a stride object that's correctly populated based
// on the shape of the tensor and the stride type passed in,
// for example:
// - if s = Stride<int, _1> and shape = {M, N, L} then the stride will be
// constructed as {N, 1}, i.e. Row Major
// - if s = Stride<_1, int> and shape = {M, N, L} then the stride will be
// constructed as {1, M}, i.e. Column Major
// - if s = Stride<int, _1, int64_t> and shape = {M, N, L} then the stride
// will be constructed as {N, 1, M * N}, i.e. Row Major Batched
// - etc.
// layout utils
////////////////////////////////////////////////////////////////////

//
// Row Major and Col Major (non-batched)
//
// Non-batched row-major stride.
//
// Returns a copy of `s` whose dynamic mode-0 stride is set to N (mode 1 of
// `shape_MNL`); mode 1 remains the static unit stride. M and L are not read.
// IntT must be an integral type so the stride can be assigned at runtime.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>> make_cute_stride(
    cute::Stride<IntT, cute::Int<1>> s, cute::Shape<int, int, int> shape_MNL) {
  static_assert(std::is_integral_v<IntT>,
                "Stride must have an integral type so it can be set "
                "dynamically. Static strides not supported.");
  cute::Stride<IntT, cute::Int<1>> result = s;
  auto const num_cols = cute::get<1>(shape_MNL);
  cute::get<0>(result) = static_cast<IntT>(num_cols);
  return result;
}

// Convenience overload for the non-batched row-major stride: packs the three
// extents into a shape and forwards to the shape-based overload.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>> make_cute_stride(
    cute::Stride<IntT, cute::Int<1>> s, int M, int N, int L) {
  auto const shape_MNL = cute::make_shape(M, N, L);
  return make_cute_stride(s, shape_MNL);
}

// Non-batched column-major stride.
//
// Returns a copy of `s` whose dynamic mode-1 stride is set to M (mode 0 of
// `shape_MNL`); mode 0 remains the static unit stride. N and L are not read.
// IntT must be an integral type so the stride can be assigned at runtime.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT> make_cute_stride(
    cute::Stride<cute::Int<1>, IntT> s, cute::Shape<int, int, int> shape_MNL) {
  static_assert(std::is_integral_v<IntT>,
                "Stride must have an integral type so it can be set "
                "dynamically. Static strides not supported.");
  cute::Stride<cute::Int<1>, IntT> result = s;
  auto const num_rows = cute::get<0>(shape_MNL);
  cute::get<1>(result) = static_cast<IntT>(num_rows);
  return result;
}

// Convenience overload for the non-batched column-major stride: packs the
// three extents into a shape and forwards to the shape-based overload.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT> make_cute_stride(
    cute::Stride<cute::Int<1>, IntT> s, int M, int N, int L) {
  auto const shape_MNL = cute::make_shape(M, N, L);
  return make_cute_stride(s, shape_MNL);
}

//
// Row Major Batched
//
// Batched row-major stride.
//
// Returns a copy of `s` with:
//   - mode 0 (dynamic, IntT) set to N,
//   - mode 1 left as the static unit stride,
//   - mode 2 (int64_t batch stride) set to M * N when L > 1, and to 0
//     otherwise.
// IntT must be an integral type so the stride can be assigned at runtime.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>, int64_t> make_cute_stride(
    cute::Stride<IntT, cute::Int<1>, int64_t> s,
    cute::Shape<int, int, int> shape_MNL) {
  static_assert(std::is_integral_v<IntT>,
                "Stride must have an integral type so it can be set "
                "dynamically. Static strides not supported.");
  auto s_copy = s;
  cute::get<0>(s_copy) = static_cast<IntT>(cute::get<1>(shape_MNL));
  int const batch_count = cute::get<2>(shape_MNL);
  if (batch_count > 1) {
    // The batch stride slot is int64_t: widen both extents before multiplying
    // so M * N neither overflows `int` nor gets truncated to a narrower IntT.
    cute::get<2>(s_copy) = static_cast<int64_t>(cute::get<0>(shape_MNL)) *
                           static_cast<int64_t>(cute::get<1>(shape_MNL));
  } else {
    cute::get<2>(s_copy) = int64_t{0};
  }
  return s_copy;
}

// Convenience overload for the batched row-major stride: packs the three
// extents into a shape and forwards to the shape-based overload.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<IntT, cute::Int<1>, int64_t> make_cute_stride(
    cute::Stride<IntT, cute::Int<1>, int64_t> s, int M, int N, int L) {
  auto const shape_MNL = cute::make_shape(M, N, L);
  return make_cute_stride(s, shape_MNL);
}

//
// Col Major Batched
//
// Batched column-major stride.
//
// Returns a copy of `s` with:
//   - mode 0 left as the static unit stride,
//   - mode 1 (dynamic, IntT) set to M,
//   - mode 2 (int64_t batch stride) set to M * N when L > 1, and to 0
//     otherwise.
// IntT must be an integral type so the stride can be assigned at runtime.
template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT, int64_t> make_cute_stride(
    cute::Stride<cute::Int<1>, IntT, int64_t> s,
    cute::Shape<int, int, int> shape_MNL) {
  static_assert(std::is_integral_v<IntT>,
                "Stride must have an integral type so it can be set "
                "dynamically. Static strides not supported.");
  auto s_copy = s;
  cute::get<1>(s_copy) = static_cast<IntT>(cute::get<0>(shape_MNL));
  int const batch_count = cute::get<2>(shape_MNL);
  if (batch_count > 1) {
    // The batch stride slot is int64_t: widen both extents before multiplying
    // so M * N neither overflows `int` nor gets truncated to a narrower IntT.
    cute::get<2>(s_copy) = static_cast<int64_t>(cute::get<0>(shape_MNL)) *
                           static_cast<int64_t>(cute::get<1>(shape_MNL));
  } else {
    cute::get<2>(s_copy) = int64_t{0};
  }
  return s_copy;
}

template <class IntT>
CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, IntT, int64_t> make_cute_stride(
cute::Stride<cute::Int<1>, IntT, int64_t> s, int M, int N, int L) {
return make_cute_stride(s, cute::make_shape(M, N, L));
// Reorders the modes of a layout according to a compile-time index list.
// The result's i-th mode is mode Perm[i] of the input, e.g.:
//   permute_layout<1, 0>(layout)    swaps the two dimensions
//   permute_layout<0, 2, 1>(layout) swaps the last two dimensions
// The number of indices must equal the rank of the layout (checked at compile
// time).
template <size_t... Perm, typename Layout>
auto permute_layout(Layout l) {
  constexpr auto kNumIndices = sizeof...(Perm);
  static_assert(rank(l) == kNumIndices, "Invalid permutation, rank mismatch");
  return cute::make_layout(cute::get<Perm>(l)...);
}

////////////////////////////////////////////////////////////////////
Expand All @@ -120,4 +28,6 @@ static constexpr auto get_logical_ptr(PointerType* ptr) {
} else {
return ptr;
}
}
}

}; // namespace cute
91 changes: 72 additions & 19 deletions csrc/cutlass_extensions/torch_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,94 @@

#include <torch/all.h>

#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"

using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;

static inline bool is_row_major(torch::Tensor const tensor) {
TORCH_CHECK(tensor.dim() == 2);
return tensor.is_contiguous();
namespace cute {

namespace detail {

// Calls `f(get<Idx>(t), Idx)` for every index in the sequence and passes all
// of the results, in index order, to `g` as a single argument list.
template <class T, class F, class G, int... Idx>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
                                                seq<Idx...>) {
  return g(f(get<Idx>(static_cast<T&&>(t)), Idx)...);
}

static inline bool is_column_major(torch::Tensor const tensor) {
TORCH_CHECK(tensor.dim() == 2);
return tensor.stride(0) == 1 && tensor.stride(1) == tensor.size(0);
// Builds a shape whose entries are `f(Idx)` for each index in the sequence,
// in index order.
template <class F, int... Idx>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<Idx...>) {
  return make_shape(f(Idx)...);
}

template <typename T, typename Layout = RowMajor>
T* data_ptr(torch::Tensor const tensor, char const* name) {
if constexpr (std::is_same_v<Layout, RowMajor>) {
TORCH_CHECK(is_row_major(tensor), "Expected ", name, " to be RowMajor");
} else if constexpr (std::is_same_v<Layout, ColumnMajor>) {
TORCH_CHECK(is_column_major(tensor), "Expected ", name,
" to be ColumnMajor");
}; // namespace detail

// Index-aware transform over a tuple.
//
// If `t` is a tuple, returns a cute tuple whose i-th element is
// `f(get<i>(t), i)`. If `t` is not a tuple, returns `f(t)` (note that in this
// case `f` is invoked with the value only, without an index).
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
  if constexpr (is_tuple<T>::value) {
    return detail::tapply_with_idx(
        t, f, [](auto const&... a) { return cute::make_tuple(a...); },
        tuple_seq<T>{});
  } else {
    return f(t);
  }

  CUTE_GCC_UNREACHABLE;
}

// Builds a rank-N shape by evaluating `f` at each index:
//   make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
  using IndexSeq = make_seq<N>;
  return detail::make_shape_from_idx(f, IndexSeq{});
}

template <typename T, typename Layout = RowMajor>
T* maybe_data_ptr(c10::optional<torch::Tensor const> maybe_tensor,
char const* name) {
return (maybe_tensor) ? data_ptr<T, Layout>(*maybe_tensor, name) : nullptr;
}; // namespace cute

// Builds a cute layout (shape + stride) describing `tensor`, using `Stride`
// as the target stride type.
//
// - The tensor may have fewer dimensions than `Stride` has modes, but not
//   more (checked at runtime).
// - For each mode that exists in the tensor:
//     * a static stride element is verified against `tensor.stride(idx)` and
//       kept as-is; a mismatch raises an error mentioning `name`;
//     * a dynamic stride element is filled from `tensor.stride(idx)`.
// - For modes beyond `tensor.dim()`, the stride element is value-initialized
//   and the corresponding extent is 1.
//
// `name` is only used to make error messages more descriptive.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
                                    std::string_view name = "tensor") {
  constexpr int kStrideRank = rank(Stride{});
  TORCH_CHECK(tensor.dim() <= kStrideRank, "Expected ", name,
              " to have at most ", kStrideRank, " dimensions, but got ",
              tensor.dim());
  auto stride = cute::transform_with_idx(
      Stride{}, [&](auto const& stride_ele, auto const& idx) {
        using StrideEle = std::decay_t<decltype(stride_ele)>;

        // Mode has no counterpart in the tensor: use the default stride.
        if (tensor.dim() <= idx) {
          return StrideEle{};
        }

        if constexpr (cute::is_static_v<StrideEle>) {
          // Compile-time strides cannot be changed, so the tensor must match.
          TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", name,
                      ".stride(", idx, ") to be ", StrideEle::value);
          return StrideEle{};
        } else {
          return tensor.stride(idx);
        }
      });

  // Extents come from the tensor; missing trailing modes get extent 1.
  auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
    if (idx < tensor.dim())
      return tensor.size(idx);
    else
      return int64_t(1);
  });

  return make_layout(shape, stride);
}

// Optional-aware wrapper around make_cute_layout: returns an empty
// std::optional when `tensor` holds no value, otherwise the layout built from
// the contained tensor (with `name` forwarded for error messages).
template <typename Stride>
static inline auto maybe_make_cute_layout(
    c10::optional<torch::Tensor> const& tensor,
    std::string_view name = "tensor") {
  // Unevaluated context: only used to name the resulting layout type.
  using Layout = decltype(make_cute_layout<Stride>(*tensor));
  using MaybeLayout = std::optional<Layout>;

  if (!tensor) {
    return MaybeLayout{};
  }
  return MaybeLayout{make_cute_layout<Stride>(*tensor, name)};
}

//
Expand Down
105 changes: 74 additions & 31 deletions csrc/quantization/machete/machete_mm_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,18 @@ struct MacheteKernelTemplate {
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZero = ZeroT;
using ElementScale = ScaleT;
using ElementZ = ZeroT;
using ElementS = ScaleT;

using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue

using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementScale, ElementZero>,
cute::tuple<ElementB, ElementScale>>,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
ElementB>;

using LayoutA = cutlass::layout::RowMajor;
Expand All @@ -65,6 +66,13 @@ struct MacheteKernelTemplate {
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;

// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;

using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
Expand Down Expand Up @@ -115,53 +123,88 @@ struct MacheteKernelTemplate {
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = typename GemmKernel::StrideC;
using StrideD = typename GemmKernel::StrideD;
using StrideS = typename CollectiveMainloop::StrideScale;

// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;

using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;

static Arguments create_arguments(cudaStream_t stream, int M, int N, int K,
ElementA const* A, ElementB const* B,
ElementC const* C, ElementD* D,
ElementScale const* scales,
ElementZero const* zeros,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is an KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
ElementS const* S_ptr, // S is an scale_KxN matrix
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
ElementZ const* Z_ptr, // Z is an scale_KxN matrix
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
TORCH_CHECK(with_C || (!with_C && beta == 0));
TORCH_CHECK(with_scales || !scales);
TORCH_CHECK(with_zeropoints || !zeros);

static int constexpr L = 1;
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);

int const group_size = maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;

// stride_B is unused (since B is prepacked), but still required by cutlass
auto stride_A = make_cute_stride(StrideA{}, N, K, L);
auto stride_B = make_cute_stride(_StrideB{}, M, K, L);
auto stride_C = make_cute_stride(StrideC{}, N, M, L);
auto stride_D = make_cute_stride(StrideD{}, N, M, L);
auto stride_S = make_cute_stride(StrideS{}, N, scale_k, L);
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);

if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
} else {
TORCH_CHECK(!C_ptr, "C not supported");
}

if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
}

if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
}

// Transpose A and D
// A doesn't need to be transposed since cutlass expects an NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}

MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C, stride_C, D, stride_D};
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};

if constexpr (with_scales && with_zeropoints) {
mainloop_arguments = MainloopArguments{
B, stride_B, A, stride_A, scales, stride_S, group_size, zeros};
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{
B, stride_B, A, stride_A, scales, stride_S, group_size};
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
} else {
mainloop_arguments = MainloopArguments{B, stride_B, A, stride_A};
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}

return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
Expand Down
Loading

0 comments on commit 8b21235

Please sign in to comment.