
Commit

review comments
LucasWilkinson committed Aug 20, 2024
1 parent bd4dc71 commit a280110
Showing 9 changed files with 385 additions and 312 deletions.
14 changes: 8 additions & 6 deletions csrc/cutlass_extensions/cute_utils.cuh
@@ -25,8 +25,9 @@ CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
   else {
     constexpr auto coalesced_layout = coalesce(Layout{});
     if constexpr (rank(coalesced_layout) == 1 &&
-                  stride<0>(coalesced_layout) == 1)
+                  stride<0>(coalesced_layout) == 1) {
       return true;
+    }
     return false;
   }
 }
@@ -51,16 +51,17 @@ static constexpr auto get_logical_ptr(PointerType* ptr) {
 template <typename T, typename Elements>
 CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
   constexpr auto bits = sizeof_bits_v<T> * Elements{};
-  if constexpr (bits % 128 == 0)
+  if constexpr (bits % 128 == 0) {
     return AutoVectorizingCopyWithAssumedAlignment<128>{};
-  else if constexpr (bits % 64 == 0)
+  } else if constexpr (bits % 64 == 0) {
     return AutoVectorizingCopyWithAssumedAlignment<64>{};
-  else if constexpr (bits % 32 == 0)
+  } else if constexpr (bits % 32 == 0) {
     return AutoVectorizingCopyWithAssumedAlignment<32>{};
-  else if constexpr (bits % 16 == 0)
+  } else if constexpr (bits % 16 == 0) {
     return AutoVectorizingCopyWithAssumedAlignment<16>{};
-  else
+  } else {
     return AutoVectorizingCopyWithAssumedAlignment<8>{};
+  }
 }
 
 };  // namespace cute
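
Note: the cascade above picks the widest vectorized copy whose bit-alignment evenly divides the total copy size (sizeof_bits_v<T> * Elements). A minimal stand-alone sketch of that selection rule, using plain integers instead of cute copy atoms (select_alignment_bits and the 16-bit x 8-element example are illustrative, not part of this change):

// Illustrative sketch only; mirrors the alignment cascade with plain ints.
#include <cstdio>

constexpr int select_alignment_bits(int total_bits) {
  // Same cascade as create_auto_vectorizing_copy: choose the widest
  // power-of-two alignment (in bits) that evenly divides the copy size.
  if (total_bits % 128 == 0) return 128;
  if (total_bits % 64 == 0) return 64;
  if (total_bits % 32 == 0) return 32;
  if (total_bits % 16 == 0) return 16;
  return 8;
}

int main() {
  constexpr int bits = 16 * 8;  // e.g. a 16-bit element copied 8 at a time
  static_assert(select_alignment_bits(bits) == 128, "expect the 128-bit path");
  std::printf("selected alignment: %d bits\n", select_alignment_bits(bits));
  return 0;
}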
7 changes: 4 additions & 3 deletions csrc/cutlass_extensions/torch_utils.hpp
@@ -17,7 +17,7 @@ namespace detail {
 template <class T, class F, class G, int... I>
 CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
                                                 seq<I...>) {
-  return g(f(get<I>(static_cast<T&&>(t)), I)...);
+  return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
 }
 
 template <class F, int... I>
@@ -29,7 +29,7 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
 
 template <class T, class F>
 CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
-  if constexpr (is_tuple<T>::value) {
+  if constexpr (cute::is_tuple<T>::value) {
     return detail::tapply_with_idx(
         t, f, [](auto const&... a) { return cute::make_tuple(a...); },
         tuple_seq<T>{});
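
Note: a std::tuple analogue of the tapply_with_idx / transform_with_idx helpers touched above, showing that each element is handed to the callable together with its position. The names below are hypothetical stand-ins, not the cute/vLLM implementation:

// Illustrative sketch only; plain-std version of an indexed tuple transform.
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>

template <class Tuple, class F, std::size_t... I>
constexpr auto transform_with_idx_impl(Tuple const& t, F&& f,
                                       std::index_sequence<I...>) {
  // Mirrors tapply_with_idx: call f(element, index) and re-pack the results.
  return std::make_tuple(f(std::get<I>(t), I)...);
}

template <class... Ts, class F>
constexpr auto transform_with_idx(std::tuple<Ts...> const& t, F&& f) {
  return transform_with_idx_impl(t, std::forward<F>(f),
                                 std::index_sequence_for<Ts...>{});
}

int main() {
  auto shifted = transform_with_idx(
      std::make_tuple(10, 20, 30),
      [](int v, std::size_t i) { return v + static_cast<int>(i); });
  assert(shifted == std::make_tuple(10, 21, 32));
  return 0;
}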
@@ -72,8 +72,9 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
         }
       } else {
         // Extra strides are assumed to be 0 or 1
-        if constexpr (cute::is_static_v<StrideEle>)
+        if constexpr (cute::is_static_v<StrideEle>) {
           static_assert(StrideEle::value == 0 || StrideEle::value == 1);
+        }
         return StrideEle{};
       }
     });
2 changes: 1 addition & 1 deletion csrc/cutlass_extensions/vllm_numeric_conversion.cuh
@@ -524,7 +524,7 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
     // Below constructs the following temporary:
     uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
     static_assert(RegArray::kElements <= 4,
-                  "Too many inputs for BF16 -> I4 vector converter");
+                  "Too many inputs for uint4b8_t -> BF16 vector converter");
     CUTLASS_PRAGMA_UNROLL
     for (int ii = 0; ii < RegArray::kElements; ++ii) {
       asm volatile(
4 changes: 2 additions & 2 deletions csrc/ops.h
@@ -88,7 +88,7 @@ namespace machete {
 std::vector<std::string> supported_schedules(
     vllm::ScalarTypeTorchPtr const& btype);
 
-torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B,
+torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                    vllm::ScalarTypeTorchPtr const& btype,
                    c10::optional<torch::Tensor> const& scales,
                    c10::optional<torch::Tensor> const& zeros,
@@ -97,7 +97,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B,
                    c10::optional<double> alpha, c10::optional<double> beta,
                    c10::optional<std::string> schedule);
 
-torch::Tensor prepack_B(torch::Tensor const B,
+torch::Tensor prepack_B(torch::Tensor const& B,
                         vllm::ScalarTypeTorchPtr const& btype);
 
 };  // namespace machete
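
Note: the ops.h hunks switch the A and B parameters from pass-by-value (torch::Tensor const) to pass-by-const-reference (torch::Tensor const&), which avoids copying the tensor handle (and its refcount bump) on every call. A toy sketch of that difference with a stand-in handle type (Handle, by_value, and by_const_ref are made up for illustration, not part of the machete API):

// Illustrative sketch only; counts handle copies for the two parameter styles.
#include <cstdio>

struct Handle {
  static inline int copies = 0;
  Handle() = default;
  Handle(Handle const&) { ++copies; }  // stands in for a Tensor refcount bump
};

void by_value(Handle const h) { (void)h; }       // copies the handle
void by_const_ref(Handle const& h) { (void)h; }  // no copy

int main() {
  Handle t;
  by_value(t);      // Handle::copies becomes 1
  by_const_ref(t);  // no additional copy
  std::printf("handle copies made: %d\n", Handle::copies);
  return 0;
}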