[Inference] Refactor #5582

Merged
2 changes: 1 addition & 1 deletion extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
@@ -1,7 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "utils/vector_copy_utils.h"
#include "utils/vec_copy.h"
#include "../common/micros.h"

template<typename scalar_t, bool Aligned, int VecSize>
2 changes: 1 addition & 1 deletion extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -1,7 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "utils/vector_copy_utils.h"
#include "utils/vec_copy.h"
#include "../common/micros.h"

template<typename scalar_t, bool Aligned, int VecSize>
@@ -16,8 +16,10 @@ namespace funcs {
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

// Note(LiuYang): This file provides basic math operations for data types,
// including POD and CUDA built-in types such as half and __nv_bfloat16
template <typename LT, typename RT, typename RET, BinaryOpType Op>
// including POD and CUDA built-in types such as half and __nv_bfloat16.
// Implementations of common, simple binary operators should be placed here;
// otherwise, they should go in a new file under the functors dir.
template <typename LT, typename RT, typename RET, BinaryOpType op_type>
struct BinaryOpFunctor;

#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
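For readers skimming the new functor layout, here is a minimal usage sketch (an editor's illustration, not part of the diff); it assumes the macro above generates a float/kAdd specialization and that the header is funcs/binary_functor.h, as the block_reduce.h change below suggests:

```cuda
#include "funcs/binary_functor.h"  // assumed relative path; adjust to the including file

using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

// Invoke the functor like a plain callable; the kAdd specialization reduces to a + b.
__device__ float add_example(float a, float b) {
  BinaryOpFunctor<float, float, float, BinaryOpType::kAdd> add_op;
  return add_op(a, b);
}
```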
26 changes: 0 additions & 26 deletions extensions/csrc/cuda/funcs/cast_functor.h
@@ -16,32 +16,6 @@ namespace colossalAI {
namespace cuda {
namespace funcs {

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality

template <>
struct TypeConverter<half2> {
using Type = at::Half;
};

template <>
struct TypeConverter<at::Half> {
using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
using Type = at::BFloat16;
};

template <>
struct TypeConverter<at::BFloat16> {
using Type = __nv_bfloat162;
};

template <typename From, typename To>
struct CastFunctor : public std::unary_function<From, To> {
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
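With TypeConverter gone, CastFunctor is the remaining entry point in this header. A hedged usage sketch (editor's illustration; only the generic static_cast fallback is visible in the diff, so any dedicated float-to-half specialization is an assumption):

```cuda
#include <cuda_fp16.h>

#include "funcs/cast_functor.h"  // assumed relative path

using colossalAI::cuda::funcs::CastFunctor;

// Convert a float element to half via the functor; falls back to static_cast
// unless a more specific specialization exists in the header.
__device__ void cast_example(const float* in, half* out, int idx) {
  out[idx] = CastFunctor<float, half>()(in[idx]);
}
```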
46 changes: 46 additions & 0 deletions extensions/csrc/cuda/funcs/unary_functor.h
@@ -0,0 +1,46 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <functional>

#include "../utils/micros.h"

namespace colossalAI {
namespace cuda {
namespace funcs {

// Note(LiuYang): Serves as a lookup table recording which unary operations
// are already supported
enum class UnaryOpType { kLog2Ceil = 0 };

// Note(LiuYang): Implementations of common, simple unary operators should be
// placed here; otherwise, they should be placed in a new file under the
// functors dir.
template <typename From, typename To, UnaryOpType op_type>
struct UnaryOpFunctor;

#define COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION( \
FROM, TO, UNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
template <ARGS> \
struct UnaryOpFunctor<FROM, TO, UNARY_OP_TYPE> \
: public std::unary_function<FROM, TO> { \
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
};

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
HOSTDEVICE, {
int log2_value = 0;
while ((1 << log2_value) < val)
++log2_value;
return log2_value;
})

#undef COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION

} // namespace funcs
} // namespace cuda
} // namespace colossalAI
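A short usage sketch of the new kLog2Ceil functor (editor's illustration, not part of the diff): it returns the exponent of the smallest power of two that is at least val, e.g. 5 maps to 3 because 2^3 = 8.

```cuda
#include "funcs/unary_functor.h"  // assumed relative path

using colossalAI::cuda::funcs::UnaryOpFunctor;
using colossalAI::cuda::funcs::UnaryOpType;

// Handy for rounding a sequence length up to the next power-of-two exponent.
__host__ __device__ int next_pow2_exponent(int n) {
  return UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(n);
}
```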
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "utils/vector_copy_utils.h"
#include "utils/vec_copy.h"
#include "../common/micros.h"
#include "../common/mp_type_traits.h"

2 changes: 1 addition & 1 deletion extensions/csrc/cuda/get_cos_and_sin_kernel.cu
@@ -1,7 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "utils/vector_copy_utils.h"
#include "utils/vec_copy.h"
#include "../common/micros.h"
#include "stdio.h"

62 changes: 29 additions & 33 deletions extensions/csrc/cuda/include/block_reduce.h
@@ -4,15 +4,14 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "../funcs/op_functor.h"
#include "../funcs/binary_functor.h"

namespace colossalAI {
namespace cuda {
namespace utils {

const float kReduceFloatInfNeg = -100000000.f;
const float kReduceFloatInfPos = 100000000.f;
const int kWarpSize = 32;
const unsigned int kWarpReduceMask = 0xffffffff;

enum class ReduceType { kMax = 0, kSum };
@@ -31,44 +30,42 @@ struct GetOpForReduceType<T, ReduceType::kSum> {
};

#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
for (int offset = 0; offset < LANES; ++offset) { \
_Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \
*(VAL_PTR + offset) = \
OP(*(VAL_PTR + offset), \
__shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \
}

#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES)

#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \
DEFAULT_VALUE, REDUCE_TYPE) \
__shared__ T shm[LANES][32]; \
int lane_id = threadIdx.x & 0x1f; \
int warp_id = threadIdx.x >> 5; \
\
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR); \
if (lane_id == 0) { \
for (int offset = 0; offset < LANES; ++offset) { \
shm[offset][warp_id] = *(VAL_PTR + offset); \
} \
} \
__syncthreads(); \
\
for (int offset = 0; offset < LANES; ++offset) { \
*(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \
? shm[offset][lane_id] \
: static_cast<T>(DEFAULT_VALUE); \
} \
#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES) \
_Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) { \
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
}

#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \
REDUCE_TYPE) \
__shared__ T shm[LANES][32]; \
int lane_id = threadIdx.x & 0x1f; \
int warp_id = threadIdx.x >> 5; \
\
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR); \
if (lane_id == 0) { \
for (int offset = 0; offset < LANES; ++offset) { \
shm[offset][warp_id] = *(VAL_PTR + offset); \
} \
} \
__syncthreads(); \
\
_Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \
*(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \
? shm[offset][lane_id] \
: static_cast<T>(DEFAULT_VALUE); \
} \
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);

template <typename T, ReduceType rtype, int lanes>
template <typename T, ReduceType rtype, int lanes, int width = 32>
__forceinline__ __device__ void warp_reduce(T* pval) {
typename GetOpForReduceType<T, rtype>::Op op;
COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes);
COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes);
}

template <typename T, ReduceType rtype>
@@ -84,8 +81,7 @@ template <typename T, ReduceType rtype, int lanes>
__forceinline__ __device__ void block_reduce(T* pval) {
constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
typename GetOpForReduceType<T, rtype>::Op op;
COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue,
rtype);
COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype);
}

#undef COLOSSAL_SHFL_FUNCTION
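To show how the reworked reduction is driven, a hedged kernel sketch (editor's illustration, not part of the diff; it assumes blockDim.x is a multiple of 32 and at most 1024, which the shared-memory layout above requires, and that the header is included as block_reduce.h):

```cuda
#include "block_reduce.h"  // assumed include path

using namespace colossalAI::cuda::utils;

// Each block reduces one row of `in` (shape [rows, cols]) to its maximum.
__global__ void row_max_kernel(const float* in, float* out, int cols) {
  float val = kReduceFloatInfNeg;  // identity value for the max reduction
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    val = fmaxf(val, in[blockIdx.x * cols + i]);
  }
  // lanes = 1: each thread carries a single partial value into the reduction.
  block_reduce<float, ReduceType::kMax, 1>(&val);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = val;
  }
}
```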
26 changes: 5 additions & 21 deletions extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp
@@ -6,19 +6,15 @@

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
int attn_heads);
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
int attn_heads);

torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor) {
@@ -46,25 +42,13 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
m.def("forward", &fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");

m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
m.def("backward", &bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");

m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::
get_batch_per_block,
m.def("get_batch_per_block", &get_batch_per_block,
"Return Batch per block size.");
}
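The same flattened-binding pattern in a self-contained form (editor's illustration; the `scale` function is made up and not part of this repo):

```cuda
#include <torch/extension.h>

// Stand-in free function; the real module above binds fwd, bwd and
// get_batch_per_block the same way, without nested namespace qualification.
torch::Tensor scale(torch::Tensor const& input, double factor) {
  return input * factor;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale", &scale, "Scale a tensor by a constant factor.");
}
```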
@@ -6,10 +6,6 @@

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
@@ -40,15 +36,9 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
m.def("forward", &fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
m.def("backward", &bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}