diff --git a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu index 3300fad47796..b45daea47504 100644 --- a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu @@ -1,7 +1,7 @@ #include #include -#include "utils/vector_copy_utils.h" +#include "utils/vec_copy.h" #include "../common/micros.h" template diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 3fcceac6b942..e0cfbbed7505 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -1,7 +1,7 @@ #include #include -#include "utils/vector_copy_utils.h" +#include "utils/vec_copy.h" #include "../common/micros.h" template diff --git a/extensions/csrc/cuda/funcs/op_functor.h b/extensions/csrc/cuda/funcs/binary_functor.h similarity index 94% rename from extensions/csrc/cuda/funcs/op_functor.h rename to extensions/csrc/cuda/funcs/binary_functor.h index 0398ea97b539..2f26e71977b1 100644 --- a/extensions/csrc/cuda/funcs/op_functor.h +++ b/extensions/csrc/cuda/funcs/binary_functor.h @@ -16,8 +16,10 @@ namespace funcs { enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin }; // Note(LiuYang): This file provides base math operation for data type -// include POD and cuda built-in type such as half and __nv_bfloat16 -template +// include POD and cuda built-in type such as half and __nv_bfloat16. +// Implementation of common and simple binary operators should be placed here, +// otherwise, they should be placed in a new file under functors dir. +template struct BinaryOpFunctor; #define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \ diff --git a/extensions/csrc/cuda/funcs/cast_functor.h b/extensions/csrc/cuda/funcs/cast_functor.h index 623e1cdeb290..dbb7195d099a 100644 --- a/extensions/csrc/cuda/funcs/cast_functor.h +++ b/extensions/csrc/cuda/funcs/cast_functor.h @@ -16,32 +16,6 @@ namespace colossalAI { namespace cuda { namespace funcs { -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = at::Half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = at::BFloat16; -}; - -template <> -struct TypeConverter { - using Type = __nv_bfloat162; -}; - template struct CastFunctor : public std::unary_function { HOSTDEVICE To operator()(From val) { return static_cast(val); } diff --git a/extensions/csrc/cuda/funcs/unary_functor.h b/extensions/csrc/cuda/funcs/unary_functor.h new file mode 100644 index 000000000000..72c421ea1cbe --- /dev/null +++ b/extensions/csrc/cuda/funcs/unary_functor.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "../utils/micros.h" + +namespace colossalAI { +namespace cuda { +namespace funcs { + +// Note(LiuYang): As a retrieved table to check which operation is supported +// already +enum class UnaryOpType { kLog2Ceil = 0 }; + +// Note(LiuYang): Implementation of common and simple unary operators should be +// placed here, otherwise, they should be placed in a new file under functors +// dir. +template +struct UnaryOpFunctor; + +#define COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION( \ + FROM, TO, UNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) 
\ + template \ + struct UnaryOpFunctor \ + : public std::unary_function { \ + FUNCTION_MODIFIER TO operator()(FROM val) STMTS \ + }; + +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil, + HOSTDEVICE, { + int log2_value = 0; + while ((1 << log2_value) < val) + ++log2_value; + return log2_value; + }) + +#undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION + +} // namespace funcs +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu index 8feb6b343620..e5766e981167 100644 --- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -2,7 +2,7 @@ #include #include -#include "utils/vector_copy_utils.h" +#include "utils/vec_copy.h" #include "../common/micros.h" #include "../common/mp_type_traits.h" diff --git a/extensions/csrc/cuda/get_cos_and_sin_kernel.cu b/extensions/csrc/cuda/get_cos_and_sin_kernel.cu index 15aea740e6f9..15b5c5efbdcf 100644 --- a/extensions/csrc/cuda/get_cos_and_sin_kernel.cu +++ b/extensions/csrc/cuda/get_cos_and_sin_kernel.cu @@ -1,7 +1,7 @@ #include #include -#include "utils/vector_copy_utils.h" +#include "utils/vec_copy.h" #include "../common/micros.h" #include "stdio.h" diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h index 6f6db6f774ab..a9bd537f7cba 100644 --- a/extensions/csrc/cuda/include/block_reduce.h +++ b/extensions/csrc/cuda/include/block_reduce.h @@ -4,7 +4,7 @@ #include #include -#include "../funcs/op_functor.h" +#include "../funcs/binary_functor.h" namespace colossalAI { namespace cuda { @@ -12,7 +12,6 @@ namespace utils { const float kReduceFloatInfNeg = -100000000.f; const float kReduceFloatInfPos = 100000000.f; -const int kWarpSize = 32; const unsigned int kWarpReduceMask = 0xffffffff; enum class ReduceType { kMax = 0, kSum }; @@ -31,44 +30,42 @@ struct GetOpForReduceType { }; #define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \ - for (int offset = 0; offset < LANES; ++offset) { \ + _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \ *(VAL_PTR + offset) = \ OP(*(VAL_PTR + offset), \ __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \ } -#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \ - COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \ - COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \ - COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \ - COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \ - COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES) - -#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \ - DEFAULT_VALUE, REDUCE_TYPE) \ - __shared__ T shm[LANES][32]; \ - int lane_id = threadIdx.x & 0x1f; \ - int warp_id = threadIdx.x >> 5; \ - \ - warp_reduce(VAL_PTR); \ - if (lane_id == 0) { \ - for (int offset = 0; offset < LANES; ++offset) { \ - shm[offset][warp_id] = *(VAL_PTR + offset); \ - } \ - } \ - __syncthreads(); \ - \ - for (int offset = 0; offset < LANES; ++offset) { \ - *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \ - ? 
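// A minimal usage sketch of the new kLog2Ceil functor above (illustrative only,
// not part of the hunk itself): it replaces the file-local log2_ceil() helper
// that the deleted softmax headers carried. The template-argument form
// <From, To, op_type> is inferred from the specialization macro; treat the exact
// ordering as an assumption rather than a verbatim call site.
#include "funcs/unary_functor.h"
using colossalAI::cuda::funcs::UnaryOpFunctor;
using colossalAI::cuda::funcs::UnaryOpType;

inline int next_pow2_exponent(int value) {
  // e.g. value = 640 -> 10, since (1 << 10) = 1024 is the next power of two
  return UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(value);
}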
shm[offset][lane_id] \ - : static_cast(DEFAULT_VALUE); \ - } \ +#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES) \ + _Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) { \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \ + } + +#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \ + REDUCE_TYPE) \ + __shared__ T shm[LANES][32]; \ + int lane_id = threadIdx.x & 0x1f; \ + int warp_id = threadIdx.x >> 5; \ + \ + warp_reduce(VAL_PTR); \ + if (lane_id == 0) { \ + for (int offset = 0; offset < LANES; ++offset) { \ + shm[offset][warp_id] = *(VAL_PTR + offset); \ + } \ + } \ + __syncthreads(); \ + \ + _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \ + ? shm[offset][lane_id] \ + : static_cast(DEFAULT_VALUE); \ + } \ warp_reduce(VAL_PTR); -template +template __forceinline__ __device__ void warp_reduce(T* pval) { typename GetOpForReduceType::Op op; - COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes); + COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes); } template @@ -84,8 +81,7 @@ template __forceinline__ __device__ void block_reduce(T* pval) { constexpr T kDefaultValue = GetDefaultValueForBlockReduce(); typename GetOpForReduceType::Op op; - COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue, - rtype); + COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype); } #undef COLOSSAL_SHFL_FUNCTION diff --git a/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp b/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp index 8c2982b0cff9..427035d4e88b 100644 --- a/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp +++ b/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp @@ -6,10 +6,6 @@ #include -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); @@ -17,8 +13,8 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, - int attn_heads); +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads); torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { @@ -46,25 +42,13 @@ torch::Tensor bwd(torch::Tensor const& output_grads, return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, - attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + m.def("forward", &fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + m.def("backward", &bwd, "Self Multihead Attention scaled, time masked softmax -- Backward."); - m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax:: - get_batch_per_block, + m.def("get_batch_per_block", &get_batch_per_block, "Return Batch per block size."); } diff --git 
a/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp b/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp index cbbc3706497a..bbd65712374d 100644 --- a/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp +++ b/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp @@ -6,10 +6,6 @@ #include -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); torch::Tensor bwd_cuda(torch::Tensor const& output_grads, @@ -40,15 +36,9 @@ torch::Tensor bwd(torch::Tensor const& output_grads, return bwd_cuda(output_grads, softmax_results, scale_factor); } -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + m.def("forward", &fwd, "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + m.def("backward", &bwd, "Self Multihead Attention scaled, time masked softmax -- Backward."); } diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index c39e44d8725f..33f35ccbd550 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -11,42 +11,33 @@ #include "block_reduce.h" #include "../common/micros.h" #include "funcs/cast_functor.h" -#include "funcs/op_functor.h" +#include "funcs/binary_functor.h" using colossalAI::cuda::utils::block_reduce; using colossalAI::cuda::utils::ReduceType; -using colossalAI::cuda::funcs::TypeConverter; using colossalAI::cuda::funcs::CastFunctor; using colossalAI::cuda::funcs::BinaryOpFunctor; using colossalAI::cuda::funcs::BinaryOpType; -#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \ - if (DATA_SIZE == 2) { \ - switch (TYPE) { \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ - } else { \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - general_##__VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ - } \ + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; + +#define TYPE_CONVERTER_SPECIALIZATION(FROM, TO) \ + template <> \ + struct TypeConverter { \ + using Type = TO; \ + }; + +TYPE_CONVERTER_SPECIALIZATION(half2, at::Half) +TYPE_CONVERTER_SPECIALIZATION(at::Half, half2) +TYPE_CONVERTER_SPECIALIZATION(__nv_bfloat162, at::BFloat16) +TYPE_CONVERTER_SPECIALIZATION(at::BFloat16, __nv_bfloat162) + +#undef TYPE_CONVERTER_SPECIALIZATION // optimized for half and bf16 template @@ -217,6 +208,36 @@ __global__ void general_fused_add_rms_layernorm_kernel( } } + +#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) 
\ + if (DATA_SIZE == 2) { \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } else { \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + general_##__VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } \ + + void rms_layernorm( torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] @@ -424,3 +445,5 @@ void fused_add_rms_layernorm( } } } + +#undef DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT diff --git a/extensions/csrc/cuda/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h deleted file mode 100644 index cbbe7f36ad38..000000000000 --- a/extensions/csrc/cuda/scaled_masked_softmax.h +++ /dev/null @@ -1,500 +0,0 @@ -/*This code from NVIDIA Megatron: - * with minor changes. */ - -#pragma once - -#include -#include -#include - -#include -#include - -#include "utils/vector_copy_utils.h" - -namespace { - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t *sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional - * features 1) input scaling 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, - int micro_batch_size, int element_count, int pad_batches) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
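// A small sketch of how the relocated TypeConverter above is meant to be used
// (illustrative only, not part of the hunk; the helper name below is made up):
// it maps a scalar dtype to its packed two-element CUDA type and back, which is
// what lets the half/bf16 RMSNorm kernels process two elements per load, while
// the dispatch macro routes float inputs to the general_* kernels instead.
template <typename scalar_t>
__device__ void vectorized_load_sketch(const scalar_t* in) {
  // at::Half -> half2, at::BFloat16 -> __nv_bfloat162 (and vice versa)
  using scalar2_t = typename TypeConverter<scalar_t>::Type;
  const scalar2_t* in2 = reinterpret_cast<const scalar2_t*>(in);
  (void)in2;  // the real kernels read/write through in2 two scalars at a time
}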
1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = - (blockDim.y * - (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + - threadIdx.y) * - WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = - (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * - WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i * element_count + it * WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = - (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH]{0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, - int micro_batch_size, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = - first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : element_count; - -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector( - temp_output, output + i * element_count + it * WARP_SIZE); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = - (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = - (output_t)(scale * (grad_reg[i][it + element] - - output_reg[i][it + element] * sum[i])); - } - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, int key_seq_len, - int batches, int attn_heads, - int pad_batches) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_forward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); - dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward(output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, int key_seq_len, - int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu b/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu index 2f968d30f106..e0bb6497abae 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu +++ b/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu @@ -9,16 +9,462 @@ #include #include -#include "scaled_masked_softmax.h" +#include +#include +#include +#include + #include "../common/micros.h" +#include "utils/vec_copy.h" +#include "include/block_reduce.h" +#include "funcs/unary_functor.h" + +using colossalAI::cuda::funcs::UnaryOpFunctor; +using colossalAI::cuda::funcs::UnaryOpType; +using colossalAI::cuda::utils::warp_reduce; +using colossalAI::cuda::utils::ReduceType; + + +/* + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = + first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
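// Sketch of the reduction change (not part of the hunk): the ported kernels now
// call the shared warp_reduce from include/block_reduce.h instead of the deleted
// file-local Add/Max functors and WARP_SHFL_XOR_NATIVE. The template arguments
// shown below are a reconstruction of the refactored signature (value type,
// lanes, width, reduce type); the exact parameter order is an assumption.
//
//   acc_t max_value[WARP_BATCH];
//   acc_t sum[WARP_BATCH];
//   warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, ReduceType::kMax>(max_value);
//   warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, ReduceType::kSum>(sum);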
0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); + } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} + + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + int log2_elements = UnaryOpFunctor()(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads, + int pad_batches) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = UnaryOpFunctor()(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = UnaryOpFunctor()(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } } torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, @@ -84,6 +530,3 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, // backward pass is completely in-place return output_grads; } -} // namespace scaled_masked_softmax -} // namespace fused_softmax -} // namespace multihead_attn diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h deleted file mode 100644 index bd2465beabd2..000000000000 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h +++ /dev/null @@ -1,538 +0,0 @@ -/*This code from NVIDIA Megatron: - * with minor changes. */ - -#pragma once - -#include -#include -#include -#include - -#include -#include - -#include "utils/vector_copy_utils.h" - -namespace { - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? 
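// Worked example of the launch configuration computed in
// dispatch_scaled_masked_softmax_forward above (illustrative, derived from the
// code as written). For key_seq_len = 640:
//   log2_elements     = 10              // ceil(log2(640))
//   next_power_of_two = 1 << 10 = 1024
//   warp_size         = min(1024, C10_WARP_SIZE) = 32
//   batches_per_warp  = 1               // next_power_of_two > 128
//   warps_per_block   = 128 / 32 = 4    // threads_per_block fixed at 128
//   batches_per_block = 4 * 1 = 4       // query_seq_len must be divisible by 4
//   grid  = (query_seq_len / 4, attn_heads, batches)
//   block = (32, 4, 1)
// so the case 10 branch of the switch launches the log2_elements == 10 kernel.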
b : a; - } -}; - -template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t *sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional - * features 1) input scaling 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, - int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = - (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + - blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = - (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector( - temp_data, src + i * element_count * stride + it * WARP_SIZE); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = - (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH]{0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector( - dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector( - dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, - int micro_batch_size, int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = - (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + - blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : local_seq; - -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector( - temp_output, output + i * element_count * stride + it * WARP_SIZE); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = - (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = - (output_t)(scale * (grad_reg[i][it + element] - - output_reg[i][it + element] * sum[i])); - } - copy_vector( - gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, const input_t *src, const input_t scale, - int softmax_elements, int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_forward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, input_t *grad, const input_t *output, - const acc_t scale, int softmax_elements, int softmax_elements_stride, - int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu index d9550dc2c2a5..d44097b6b9b4 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu @@ -8,13 +8,502 @@ #include #include #include +#include +#include +#include +#include +#include -#include "scaled_upper_triang_masked_softmax.h" #include "../common/micros.h" +#include "utils/vec_copy.h" +#include "include/block_reduce.h" +#include "funcs/unary_functor.h" + +using colossalAI::cuda::funcs::UnaryOpFunctor; +using colossalAI::cuda::funcs::UnaryOpType; +using colossalAI::cuda::utils::warp_reduce; +using colossalAI::cuda::utils::ReduceType; + +/* + * Extended softmax (from native aten 
pytorch) with following additional + * features 1) input scaling 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, + int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = + (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector( + dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); + } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, const input_t *src, const input_t scale, + int softmax_elements, int softmax_elements_stride, int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = UnaryOpFunctor()(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, input_t *grad, const input_t *output, + const acc_t scale, int softmax_elements, int softmax_elements_stride, + int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = UnaryOpFunctor()(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + + -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] @@ -70,6 +559,3 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, // backward pass is completely in-place return output_grads; } -} // namespace scaled_upper_triang_masked_softmax -} // namespace fused_softmax -} // namespace multihead_attn diff --git a/extensions/csrc/cuda/utils/vector_copy_utils.h b/extensions/csrc/cuda/utils/vec_copy.h similarity index 100% rename from extensions/csrc/cuda/utils/vector_copy_utils.h rename to extensions/csrc/cuda/utils/vec_copy.h diff --git a/extensions/csrc/cuda/utils/vec_type_traits.h b/extensions/csrc/cuda/utils/vec_type_traits.h 
index 3ddd64df95fd..0bd25469a923 100644
--- a/extensions/csrc/cuda/utils/vec_type_traits.h
+++ b/extensions/csrc/cuda/utils/vec_type_traits.h
@@ -13,70 +13,27 @@ namespace utils {
 template <typename T, int VecSize>
 struct VecTypeTrait {};
 
-template <typename T>
-struct VecTypeTrait<T, 1> {
-  using Type = T;
-};
-
-template <>
-struct VecTypeTrait<c10::BFloat16, 2> {
-  using Type = float;
-};
-
-template <>
-struct VecTypeTrait<c10::BFloat16, 4> {
-  using Type = float2;
-};
-
-template <>
-struct VecTypeTrait<c10::BFloat16, 8> {
-  using Type = float4;
-};
-
-template <>
-struct VecTypeTrait<c10::Half, 2> {
-  using Type = float;
-};
-
-template <>
-struct VecTypeTrait<c10::Half, 4> {
-  using Type = float2;
-};
-
-template <>
-struct VecTypeTrait<c10::Half, 8> {
-  using Type = float4;
-};
-
-template <>
-struct VecTypeTrait<float, 2> {
-  using Type = float2;
-};
-
-template <>
-struct VecTypeTrait<float, 4> {
-  using Type = float4;
-};
-
-template <>
-struct VecTypeTrait<float, 8> {
-  using Type = float4;
-};
-
-template <>
-struct VecTypeTrait<uint8_t, 2> {
-  using Type = half;
-};
-
-template <>
-struct VecTypeTrait<uint8_t, 4> {
-  using Type = half2;
-};
-
-template <>
-struct VecTypeTrait<uint8_t, 8> {
-  using Type = float2;
-};
+#define VEC_TYPE_TRAITS_SPECIALIZATION(T, VEC_SIZE, VECT, ARGS...) \
+  template <ARGS>                                                  \
+  struct VecTypeTrait<T, VEC_SIZE> {                               \
+    using Type = VECT;                                             \
+  };
+
+VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
+VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 2, float)
+VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 4, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 2, float)
+VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 4, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
+VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
+VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
+
+#undef VEC_TYPE_TRAITS_SPECIALIZATION
 
 }  // namespace utils
 }  // namespace cuda
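The macro above only rebuilds the scalar-type-to-vector-type lookup table; its purpose is that helpers such as copy_vector<T, N> in utils/vec_copy.h (whose body is not part of this patch) can turn N scalar loads/stores into a single wide memory transaction. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions: the trait copy, the helper copy_vec_sketch, and the kernel are hypothetical names invented here, not code from this patch; only the T/VecSize -> vector-type idea mirrors the header above.

// Illustrative sketch of how a VecTypeTrait-style mapping is typically consumed.
// Assumption: dst/src are aligned to sizeof(VT) and VecSize * sizeof(T) == sizeof(VT).
#include <cuda_fp16.h>
#include <cstdint>

template <typename T, int VecSize>
struct VecTypeTraitSketch;  // hypothetical local stand-in for utils::VecTypeTrait

template <>
struct VecTypeTraitSketch<half, 8> {
  using Type = float4;  // 8 x half = 16 bytes = one float4 transaction
};

template <typename T, int VecSize>
__device__ __forceinline__ void copy_vec_sketch(T* dst, const T* src) {
  using VT = typename VecTypeTraitSketch<T, VecSize>::Type;
  // One vectorized load/store instead of VecSize scalar ones.
  *reinterpret_cast<VT*>(dst) = *reinterpret_cast<const VT*>(src);
}

__global__ void copy_kernel_sketch(half* dst, const half* src, int n) {
  // Each thread moves 8 half values (16 bytes) per iteration.
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (idx + 8 <= n) {
    copy_vec_sketch<half, 8>(dst + idx, src + idx);
  }
}

The uint8_t rows in the table follow the same logic: the chosen vector type only has to match the total byte width of the pack (2 bytes -> half, 4 bytes -> half2, 8 bytes -> float2), not the element type, since the data is reinterpreted rather than converted.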