From ec13da6ba7cabc44bb4745a64a208b8580792954 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 13 Oct 2022 16:47:12 -0700 Subject: [PATCH] add SD injection policy (#2381) Co-authored-by: Reza Yazdani Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> --- csrc/transformer/inference/csrc/gelu.cu | 118 +++++++ .../transformer/inference/csrc/pt_binding.cpp | 195 +++++++++++- csrc/transformer/inference/csrc/softmax.cu | 5 +- csrc/transformer/inference/csrc/transform.cu | 100 ++++++ .../inference/includes/inference_context.h | 19 +- .../includes/inference_cuda_layers.h | 28 ++ deepspeed/inference/engine.py | 74 +++-- deepspeed/module_inject/encoder.py | 66 ++++ deepspeed/module_inject/replace_module.py | 90 +++++- deepspeed/module_inject/replace_policy.py | 133 +++++++- deepspeed/module_inject/unet.py | 82 +++++ deepspeed/moe/utils.py | 1 + deepspeed/ops/transformer/__init__.py | 1 + .../ops/transformer/inference/__init__.py | 1 + .../ops/transformer/inference/attention.py | 296 ++++++++++++++++++ .../inference/transformer_inference.py | 5 +- .../ops/transformer/inference/triton_ops.py | 151 +++++++++ requirements/requirements-sd.txt | 2 + 18 files changed, 1305 insertions(+), 62 deletions(-) create mode 100644 deepspeed/module_inject/encoder.py create mode 100644 deepspeed/module_inject/unet.py create mode 100644 deepspeed/ops/transformer/inference/attention.py create mode 100644 deepspeed/ops/transformer/inference/triton_ops.py create mode 100644 requirements/requirements-sd.txt diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index e648e13095ef..cab8eb3fe63f 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -475,3 +475,121 @@ template void launch_moe_res_matmul(__half* residual, int seq_len, int hidden_dim, cudaStream_t stream); + +__global__ void pad_data_kernel(__half* padded_output, + __half* output, + int head_size, + int padded_head_size) +{ + float4* padded_output_cast = reinterpret_cast(padded_output); + float4* output_cast = reinterpret_cast(output); + int bid = blockIdx.x * (blockDim.y) + threadIdx.y; + int idx = threadIdx.x; + padded_output_cast += (bid * padded_head_size); + output_cast += (bid * head_size); + float4 ZERO; + const __half2 zero_h = __float2half2_rn(0.f); + __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + if (idx < head_size) + padded_output_cast[idx] = output_cast[idx]; + else + padded_output_cast[idx] = ZERO; +} +__global__ void pad_data_kernel(float* padded_output, + float* output, + int head_size, + int padded_head_size) +{ +} +template +void pad_data(T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream) +{ + dim3 grid_dim((bsz - 1) / 16 + 1); + dim3 block_dim(padded_head_size / 8, 16); + pad_data_kernel<<>>( + padded_output, output, head_size / 8, padded_head_size / 8); +} +template void pad_data(__half* padded_output, + __half* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); +template void pad_data(float* padded_output, + float* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); + +__global__ void pad_head_seq_kernel(__half* padded_output, + __half* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) +{ + float4* padded_output_cast = reinterpret_cast(padded_output); + float4* output_cast = 
reinterpret_cast(output); + int bsz = blockIdx.x; + int bid = blockIdx.y * (blockDim.y) + threadIdx.y; + int idx = threadIdx.x; + padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; + output_cast += (bsz * seq_len + bid) * head_size; + float4 ZERO; + const __half2 zero_h = __float2half2_rn(0.f); + __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + + if (idx < head_size && bid < seq_len) + padded_output_cast[idx] = output_cast[idx]; + else + padded_output_cast[idx] = ZERO; +} +__global__ void pad_head_seq_kernel(float* padded_output, + float* output, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size) +{ +} +template +void pad_head_seq(T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream) +{ + dim3 grid_dim(bsz, padded_seq_len / 16); + dim3 block_dim(padded_head_size / 8, 16); + pad_head_seq_kernel<<>>( + padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); +} +template void pad_head_seq(__half* padded_output, + __half* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); +template void pad_head_seq(float* padded_output, + float* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index fc3fa9108138..f09dfc569ed2 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -831,6 +831,10 @@ template at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias, + bool add_bias, + bool external_cache, + bool do_flash_attn, + int num_heads, unsigned num_layers) { auto input_cont = input.contiguous(); @@ -840,8 +844,23 @@ at::Tensor ds_linear_layer(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); + int head_size = input_cont.size(2) / num_heads; int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); + // Reallocate memory if we received a new prompt + if (!workspace) { + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + allocate_workspace(input.size(2), + input.size(0), + input.size(1), + num_layers, + num_heads, + 1, + external_cache, + 0); + workspace = (T*)Context::Instance().GetWorkSpace(); + } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); float alpha = (T)1.0; @@ -864,16 +883,172 @@ at::Tensor ds_linear_layer(at::Tensor& input, #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); + if (do_flash_attn) { + if (add_padding) { + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128); + auto padded_output = workspace + output.numel(); + auto final_output = + padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size); + pad_data(padded_output, + workspace, + 3 * bsz * num_heads, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size), + final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size), + padded_output, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + (num_heads * padded_head_size), + num_heads, + -1, + false, + false, + Context::Instance().GetCurrentStream(), + 3, + input.size(1)); + return at::from_blob(final_output, + {3, input.size(0), num_heads, input.size(1), padded_head_size}, + options); + // return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads, + // padded_head_size}, options); + } else { + auto final_output = workspace + output.numel(); + launch_bias_add_transform_0213( + final_output, + final_output + (input.size(0) * input.size(1) * input_cont.size(2)), + final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)), + workspace, + nullptr, + input.size(0), + input.size(1), + 0, + input.size(1), + input_cont.size(2), + num_heads, + -1, + false, + false, + Context::Instance().GetCurrentStream(), + 3, + input.size(1)); + return at::from_blob( + final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options); + // return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads, + // head_size}, options); + } + + } else + return output; +} - return output; +template +std::vector add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value) +{ + int head_size = query.size(3); + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); + T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); + T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; + pad_head_seq(workspace, + (T*)query.data_ptr(), + query.size(0) * query.size(1), + query.size(2), + query.size(2), + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + pad_head_seq(key_pad_ptr, + (T*)key.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + pad_head_seq(value_pad_ptr, + (T*)value.data_ptr(), + query.size(0) * query.size(1), + key.size(2), + 128, + head_size, + padded_head_size, + Context::Instance().GetCurrentStream()); + return { + at::from_blob(workspace, + {query.size(0), query.size(1), query.size(2), padded_head_size}, + query.options()), + at::from_blob( + key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()), + at::from_blob( + value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options())}; } +template +std::vector padd_add_transform(at::Tensor& query, + at::Tensor& key, + at::Tensor& value, + int heads, + bool add_padding) +{ + int head_size = query.size(2) / heads; + int key_value_length = add_padding ? 128 : key.size(1); + int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 
64 : 128)) + : head_size; + T* workspace = (T*)Context::Instance().GetWorkSpace(); + T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); + T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; + launch_pad_add_transform_0213(workspace, + (T*)query.data_ptr(), + query.size(0), + query.size(2), + query.size(1), + query.size(1), + heads, + padded_head_size, + Context::Instance().GetCurrentStream()); + launch_pad_add_transform_0213(key_pad_ptr, + (T*)key.data_ptr(), + key.size(0), + key.size(2), + key.size(1), + key_value_length, + heads, + padded_head_size, + Context::Instance().GetCurrentStream()); + launch_pad_add_transform_0213(value_pad_ptr, + (T*)value.data_ptr(), + value.size(0), + value.size(2), + value.size(1), + key_value_length, + heads, + padded_head_size, + Context::Instance().GetCurrentStream()); + return { + at::from_blob( + workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), + at::from_blob(key_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + query.options()), + at::from_blob(value_pad_ptr, + {query.size(0), heads, key_value_length, padded_head_size}, + query.options())}; +} template at::Tensor ds_linear_layer_int8(at::Tensor& input, at::Tensor& weight, @@ -1414,6 +1589,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); + m.def("add_padding_fp32", &add_padding, "DeepSpeed residual add with fp32 (CUDA)"); + m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)"); + m.def("pad_transform_fp32", + &padd_add_transform, + "DeepSpeed residual add with fp32 (CUDA)"); + m.def("pad_transform_fp16", + &padd_add_transform<__half>, + "DeepSpeed residual add with fp16 (CUDA)"); m.def("allocate_workspace_fp32", &allocate_workspace, "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)"); diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index ce7c2e77759d..b85ac1eb0be8 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -12,7 +12,7 @@ Copyright 2022 The Microsoft DeepSpeed Team #include #include -#define ATTN_THREADS 1024 +#define ATTN_THREADS 256 #define MAX_REG_SIZE 8 #define minus_infinity -10000.0 @@ -427,7 +427,8 @@ void launch_attn_softmax_v2(T* vals, cudaStream_t stream) { int total_count = batch_size * heads * num_seq; - dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); + int warp_num = ATTN_THREADS / WARP_SIZE; + dim3 grid_dim((total_count - 1) / (warp_num / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); dim3 block_dim(ATTN_THREADS); const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 32d2df95be63..a5a43c364ed6 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -249,6 +249,106 @@ void launch_bias_add_transform_0213<__half>(__half* output, max_out_tokens); } +// Bias add + +__global__ void pad_add_transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) +{ +} + +__global__ void pad_add_transform_0213(__half* output, + const 
__half* vals, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size) +{ +#if __CUDA_ARCH__ >= 700 + float4 ZERO; + const __half2 zero_h = __float2half2_rn(0.f); + __half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); +#pragma unroll + for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127) + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = padded_head_size * padded_seq_len; + int d0_out_stride = heads * d2_out_stride; + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride); + vals_vec += (d1 * d1_stride); + vals_vec += (d2 * d2_stride); + + output_vec += (d1 * padded_head_size); + output_vec += (d0 * d0_out_stride); + output_vec += (d2 * d2_out_stride); + + if (d3 < d2_stride && d1 < seq_length) + output_vec[d3] = vals_vec[d3]; + else + output_vec[d3] = ZERO; + +#endif +} + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); + +// [B S C*H] - > C * [B A S N] +template <> +void launch_pad_add_transform_0213(float* output, + const float* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) +{ +} +template <> +void launch_pad_add_transform_0213<__half>(__half* output, + const __half* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream) +{ + hidden_dim >>= 3; + dim3 block_dim((padded_head_size >> 3), heads, 2); + dim3 grid_dim(batch_size, padded_seq_len / 2); + pad_add_transform_0213<<>>( + output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3); +} + // Bias add template __global__ void bias_add_transform_0213(T* output, diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 2fc1e7082662..64e490ef47fc 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -101,9 +101,12 @@ class Context { size_t total_size; if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } - size_t activation_size = 16 * hidden_dim * batch_size; - size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size; - size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; + int head_size = hidden_dim / num_heads; + int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); + size_t activation_size = 32 * (head_size * padded_head_size) * batch_size; + size_t temp_size = batch_size * num_heads * MAX_OUT_TOKENS * 2; + size_t cache_size = + num_layers * batch_size * ((head_size * padded_head_size) / mp_size) * 2; size_t minimal_requirements = temp_size + (_free_memory_size > GIGABYTE ? 
500 : 100) * MEGABYTE; if (_free_memory_size < minimal_requirements) { @@ -115,12 +118,12 @@ class Context { } _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / - (activation_size + cache_size); + (activation_size + temp_size + cache_size); _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); - size_t workSpaceSize = - ((external_cache ? activation_size : (activation_size + cache_size))) * _max_seq_len * - elem_size + - temp_size; + size_t workSpaceSize = ((external_cache ? (activation_size + temp_size) + : (activation_size + temp_size + cache_size))) * + _max_seq_len * elem_size; + temp_size *= _max_seq_len * elem_size; if (rank == 0 && !_workspace) printf( "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total " diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 1f86e2d858d1..67479bbc0e50 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -172,3 +172,31 @@ void launch_bias_add_transform_0213(T* outputs, cudaStream_t stream, int trans_count, int max_out_tokens); +template +void pad_data(T* padded_output, + T* output, + int bsz, + int head_size, + int padded_head_size, + cudaStream_t stream); + +template +void pad_head_seq(T* padded_output, + T* output, + int bsz, + int seq_len, + int padded_seq_len, + int head_size, + int padded_head_size, + cudaStream_t stream); + +template +void launch_pad_add_transform_0213(T* output, + const T* vals, + int batch_size, + int hidden_dim, + int seq_length, + int padded_seq_len, + int heads, + int padded_head_size, + cudaStream_t stream); diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 89a8d8288455..6f9d90e3f743 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -14,7 +14,7 @@ from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization -from ..module_inject.replace_module import replace_transformer_layer +from ..module_inject.replace_module import replace_transformer_layer, generic_injection from ..comm.comm import init_distributed from ..pipe import PipelineModule from ..moe.utils import has_moe_layers @@ -89,7 +89,7 @@ def __init__(self, self.injection_dict = injection_dict self.mp_group = None self.mpu = mpu - self._validate_args(mpu) + self._validate_args(mpu, replace_with_kernel_inject) self.replace_method = replace_method self.quantize_merge_count = 1 self.quantization_scales = None @@ -125,7 +125,8 @@ def __init__(self, elif self.mp_world_size > 1: self._create_model_parallel_group() - moe, _ = has_moe_layers(self.module) + if isinstance(self.module, torch.nn.Module): + moe, _ = has_moe_layers(self.module) if moe and dist.get_world_size() > 1: self._create_ep_parallel_group(moe_experts) @@ -251,8 +252,9 @@ def _init_quantization_setting(self, quantization_setting): f"quantize_groups = {self.quantize_groups}", [0]) - def _validate_args(self, mpu): - if not isinstance(self.module, Module): + def _validate_args(self, mpu, replace_with_kernel_inject): + # TODO: to support SD pipeline we need to avoid this check for now + if replace_with_kernel_inject and not isinstance(self.module, Module): raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") if not isinstance(self.mp_world_size, int) or self.mp_world_size < 1: raise ValueError(f"mp_size must be an int >= 1, got 
{self.mp_world_size}") @@ -357,33 +359,38 @@ def _apply_injection_policy(self, checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - replace_transformer_layer(client_module, - self.module, - triangular_masking=self.triangular_masking, - policy=injection_policy, - mp_size=self.mp_world_size, - mp_group=self.mp_group, - ep_group=self.ep_group, - expert_mp_group=self.expert_mp_group, - config=self.config, - fp16=(self.dtype == torch.half) - or (self.dtype == torch.int8), - training=False, - return_tuple=return_tuple, - quantize=(self.dtype == torch.int8), - quantize_settings=(self.quantization_scales, - self.quantize_merge_count, - self.mlp_extra_grouping, - self.quantize_groups), - replace_with_kernel_inject=replace_with_kernel_inject, - moe=moe, - moe_experts=moe_experts, - moe_type=moe_type, - training_mp_size=training_mp_size, - checkpoint_dict=checkpoint, - save_mp_checkpoint_path=save_mp_checkpoint_path, - base_dir=base_dir, - enable_cuda_graph=self.enable_cuda_graph) + + generic_injection(self.module, + fp16=(self.dtype == torch.half) or (self.dtype == torch.int8)) + + if isinstance(self.module, torch.nn.Module): + replace_transformer_layer( + client_module, + self.module, + triangular_masking=self.triangular_masking, + policy=injection_policy, + mp_size=self.mp_world_size, + mp_group=self.mp_group, + ep_group=self.ep_group, + expert_mp_group=self.expert_mp_group, + config=self.config, + fp16=(self.dtype == torch.half) or (self.dtype == torch.int8), + training=False, + return_tuple=return_tuple, + quantize=(self.dtype == torch.int8), + quantize_settings=(self.quantization_scales, + self.quantize_merge_count, + self.mlp_extra_grouping, + self.quantize_groups), + replace_with_kernel_inject=replace_with_kernel_inject, + moe=moe, + moe_experts=moe_experts, + moe_type=moe_type, + training_mp_size=training_mp_size, + checkpoint_dict=checkpoint, + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir, + enable_cuda_graph=self.enable_cuda_graph) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -478,6 +485,9 @@ def _choose_module_key(self, sd): return 'model' def _convert_to_dtype(self): + if not isinstance(self.module, torch.nn.Module): + return + if False: #self.dtype is torch.int8 and self.quantization_scales is None: quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping) model, self.quantization_scales = quantizer.model_quantize(self.module, diff --git a/deepspeed/module_inject/encoder.py b/deepspeed/module_inject/encoder.py new file mode 100644 index 000000000000..41d593daed92 --- /dev/null +++ b/deepspeed/module_inject/encoder.py @@ -0,0 +1,66 @@ +''' +Copyright 2022 The Microsoft DeepSpeed Team +''' +import torch + + +class DSClipEncoder(torch.nn.Module): + def __init__(self, enc): + super().__init__() + enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask + self.enc = enc + self.device = self.enc.device + self.dtype = self.enc.dtype + self.cuda_graph_created = False + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + mask = torch.empty(bsz, + seq_len, + seq_len, + dtype=dtype, + device=torch.cuda.current_device()) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) + mask = mask.unsqueeze(1) + return mask + + def _graph_replay(self, *inputs, **kwargs): + for i in range(len(inputs)): + if torch.is_tensor(inputs[i]): + self.static_inputs[i].copy_(inputs[i]) 
+ for k in kwargs: + if torch.is_tensor(kwargs[k]): + self.static_kwargs[k].copy_(kwargs[k]) + self._cuda_graphs.replay() + return self.static_output + + def forward(self, *inputs, **kwargs): + if self.cuda_graph_created: + outputs = self._graph_replay(*inputs, **kwargs) + else: + self._create_cuda_graph(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) + return outputs + + def _create_cuda_graph(self, *inputs, **kwargs): + # warmup to create the workspace and cublas handle + cuda_stream = torch.cuda.Stream() + cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(cuda_stream): + for i in range(3): + ret = self._forward(*inputs, **kwargs) + torch.cuda.current_stream().wait_stream(cuda_stream) + + # create cuda_graph and assign static_inputs and static_outputs + self._cuda_graphs = torch.cuda.CUDAGraph() + self.static_inputs = inputs + self.static_kwargs = kwargs + + with torch.cuda.graph(self._cuda_graphs): + self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) + + self.cuda_graph_created = True + + def _forward(self, *inputs, **kwargs): + + return self.enc(*inputs, **kwargs) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 158507e077e0..d7fa50eca4ce 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -4,7 +4,7 @@ import deepspeed import deepspeed.ops.transformer as transformer_inference from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy -from .replace_policy import replace_policies +from .replace_policy import replace_policies, generic_policies #from ..runtime.weight_quantizer import WeightQuantization from deepspeed import comm as dist from torch import nn @@ -187,6 +187,93 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): return out +def _module_match(module): + for policy in generic_policies: + policy = policy() + if policy.match(module): + return policy + return None + + +def generic_injection(module, fp16=False): + def replace_attn(child, policy, layer_id): + policy_attn = policy.attention(child) + if policy_attn is None: + return child + if len(policy_attn) == 5: + qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn + else: + qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn + + config = transformer_inference.DeepSpeedInferenceConfig( + hidden_size=hidden_size, + heads=heads, + fp16=fp16, + triangular_masking=False, + ) + attn_module = transformer_inference.DeepSpeedAttention(config) + + def transpose(data): + data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) + data = data.reshape(data.shape[-1], data.shape[-2]) + data.to(torch.cuda.current_device()) + return data + + if len(policy_attn) == 5: + attn_module.attn_qkvw.data = transpose(qkvw.data) + else: + attn_module.attn_qkvw = None + attn_module.attn_qw.data = transpose(qw.data) + attn_module.attn_kw.data = transpose(kw.data) + attn_module.attn_vw.data = transpose(vw.data) + + attn_module.attn_qkvb = None + attn_module.attn_ow.data = transpose(attn_ow.data) + attn_module.attn_ob.data.copy_(attn_ob.data.to(torch.cuda.current_device())) + return attn_module + + if isinstance(module, torch.nn.Module): + pass + else: + try: + import diffusers + cross_attention = diffusers.models.attention.CrossAttention + new_policies = {cross_attention: replace_attn} + except ImportError: + new_policies = {} + + #replace_transformer_layer(None, + # module.text_encoder, + # training=False, + 
# replace_with_kernel_inject=True, + # triangular_masking=True) + #from .encoder import DSClipEncoder + #cg_encoder = DSClipEncoder(module.text_encoder) + #setattr(module, 'text_encoder', cg_encoder) + for name in module.__dict__.keys(): + sub_module = getattr(module, name) + policy = _module_match(sub_module) + + if policy is not None: + + def _replace_module(module, policy, layer_id=0): + for name, child in module.named_children(): + if child.__class__ in new_policies: + replaced_module = new_policies[child.__class__](child, + policy, + layer_id) + setattr(module, name, replaced_module) + layer_id += 1 + else: + layer_id = _replace_module(child, policy, layer_id=layer_id) + return layer_id + + _replace_module(sub_module, policy) + new_module = policy.apply(sub_module) + print(f"**** found and replaced {name} w. {type(new_module)}") + setattr(module, name, new_module) + + def replace_transformer_layer(orig_layer_impl, model, policy=None, @@ -251,6 +338,7 @@ def replace_transformer_layer(orig_layer_impl, Returns: Updated nn.module with replaced transformer layers """ + mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=mp_size) #, out_dim=0, in_dim=1) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index cb7c4818961a..6d72e9e46468 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -1,3 +1,6 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' from abc import ABC import torch @@ -10,6 +13,61 @@ class DSPolicy(ABC): + _orig_layer_class = None + + def __init__(self): + self.cuda_graph_supported = False + + def attention(self): + """ + Returns attention qkv and dense parameters + weight: (3*hidden, hidden) and (hidden, hidden) + bias: (3*hidden) and (hidden) + """ + raise NotImplementedError + + +class UNetPolicy(DSPolicy): + def __init__(self): + super().__init__() + try: + import diffusers + self._orig_layer_class = diffusers.models.unet_2d_condition.UNet2DConditionModel + except ImportError: + self._orig_layer_class = None + + def match(self, module): + return isinstance(module, self._orig_layer_class) + + def apply(self, module): + from .unet import DSUNet + return DSUNet(module) + + def attention(self, client_module): + qw = client_module.to_q.weight + kw = client_module.to_k.weight + vw = client_module.to_v.weight + + if qw.shape[1] == kw.shape[1]: + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + + return qkvw, \ + client_module.to_out[0].weight, \ + client_module.to_out[0].bias, \ + qw.shape[-1], \ + client_module.heads + else: + #return None + #kvw = Parameter(torch.cat((kw, vw), dim=0), requires_grad=False) + return qw, \ + kw, vw, \ + client_module.to_out[0].weight, \ + client_module.to_out[0].bias, \ + qw.shape[-1], \ + client_module.heads + + +class TransformerPolicy(DSPolicy): # a static class variable containing the HuggingFace model configuration. 
# see e.g., transformers.models.opt.configuration_opt.OPTConfig hf_model_config = None @@ -24,7 +82,7 @@ def __init__( mlp_act_func_type=ActivationFuncType.GELU, # applies layer norm before attention if `pre_attn_norm` is set to True pre_attn_norm=True): - self.cuda_graph_supported = False + super().__init__() self.inference = inference self.linear_layer = linear_layer self.scale_attention = scale_attention @@ -63,9 +121,7 @@ def layerNorm(self): raise NotImplementedError -class HFBertLayerPolicy(DSPolicy): - _orig_layer_class = None - +class HFBertLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=False): super().__init__(inference, pre_attn_norm=False) self.client_module = client_module @@ -127,9 +183,57 @@ def layerNorm(self): transformer_layernorm.bias -class HFGPTNEOLayerPolicy(DSPolicy): - _orig_layer_class = None +class HFCLIPLayerPolicy(TransformerPolicy): + def __init__(self, client_module, inference=False): + super().__init__(inference, pre_attn_norm=True, scale_attention=False) + self.client_module = client_module + self.cuda_graph_supported = True + + if HFCLIPLayerPolicy._orig_layer_class is None: + try: + import transformers + HFCLIPLayerPolicy._orig_layer_class = transformers.models.clip.modeling_clip.CLIPEncoderLayer + except: + HFCLIPLayerPolicy._orig_layer_class = None + def get_hidden_heads(self): + return self.client_module.self_attn.q_proj.weight.shape[1], \ + self.client_module.self_attn.num_heads + + def attention(self): + qw = self.client_module.self_attn.q_proj.weight + qb = self.client_module.self_attn.q_proj.bias + kw = self.client_module.self_attn.k_proj.weight + kb = self.client_module.self_attn.k_proj.bias + vw = self.client_module.self_attn.v_proj.weight + vb = self.client_module.self_attn.v_proj.bias + + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False) + + return self.linear_layer, \ + qkvw, \ + qkvb, \ + self.client_module.self_attn.out_proj.weight, \ + self.client_module.self_attn.out_proj.bias, \ + self.scale_attention, \ + self.is_megatron_v2 + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.fc1.weight, \ + self.client_module.mlp.fc1.bias, \ + self.client_module.mlp.fc2.weight, \ + self.client_module.mlp.fc2.bias + + def layerNorm(self): + return self.client_module.layer_norm2.weight, \ + self.client_module.layer_norm2.bias, \ + self.client_module.layer_norm1.weight, \ + self.client_module.layer_norm1.bias + + +class HFGPTNEOLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=True): super().__init__(inference, scale_attention=False) self.client_module = client_module @@ -172,7 +276,7 @@ def layerNorm(self): self.client_module.ln_1.bias -class HFGPTJLayerPolicy(DSPolicy): +class HFGPTJLayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -217,7 +321,7 @@ def layerNorm(self): self.client_module.ln_1.bias -class MegatronLayerPolicy(DSPolicy): +class MegatronLayerPolicy(TransformerPolicy): _orig_layer_class = None version = 0 moe_type = 'standard' @@ -297,7 +401,7 @@ def layerNorm(self): self.client_module.input_layernorm.bias -class HFGPT2LayerPolicy(DSPolicy): +class HFGPT2LayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -337,7 +441,7 @@ def layerNorm(self): self.client_module.ln_1.bias -class BLOOMLayerPolicy(DSPolicy): +class BLOOMLayerPolicy(TransformerPolicy): 
_orig_layer_class = None def __init__(self, client_module, inference=True): @@ -379,7 +483,7 @@ def layerNorm(self): self.client_module.input_layernorm.bias -class GPTNEOXLayerPolicy(DSPolicy): +class GPTNEOXLayerPolicy(TransformerPolicy): _orig_layer_class = None version = 0 @@ -433,7 +537,7 @@ def layerNorm(self): self.client_module.input_layernorm.bias -class HFOPTLayerPolicy(DSPolicy): +class HFOPTLayerPolicy(TransformerPolicy): _orig_layer_class = None def __init__(self, client_module, inference=True): @@ -490,6 +594,7 @@ def layerNorm(self): self.client_module.self_attn_layer_norm.bias +# transformer-based policies replace_policies = [ HFBertLayerPolicy, HFGPTNEOLayerPolicy, @@ -499,4 +604,8 @@ def layerNorm(self): HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, + HFCLIPLayerPolicy, ] + +# non-transformer-based policies +generic_policies = [UNetPolicy] diff --git a/deepspeed/module_inject/unet.py b/deepspeed/module_inject/unet.py new file mode 100644 index 000000000000..a667d94e6c83 --- /dev/null +++ b/deepspeed/module_inject/unet.py @@ -0,0 +1,82 @@ +''' +Copyright 2022 The Microsoft DeepSpeed Team +''' +import torch +import diffusers + + +class DSUNet(torch.nn.Module): + def __init__(self, unet): + super().__init__() + self.unet = unet + # SD pipeline accesses this attribute + self.in_channels = unet.in_channels + self._traced_unet = None + self._trace_enabled = False + self.device = self.unet.device + self.dtype = self.unet.dtype + self.fwd_count = 0 + self.unet.requires_grad_(requires_grad=False) + self.unet.to(memory_format=torch.channels_last) + self.cuda_graph_created = False + + def _graph_replay(self, *inputs, **kwargs): + for i in range(len(inputs)): + if torch.is_tensor(inputs[i]): + self.static_inputs[i].copy_(inputs[i]) + for k in kwargs: + if torch.is_tensor(kwargs[k]): + self.static_kwargs[k].copy_(kwargs[k]) + self._cuda_graphs.replay() + return self.static_output + + def forward(self, *inputs, **kwargs): + if self.cuda_graph_created: + outputs = self._graph_replay(*inputs, **kwargs) + else: + self._create_cuda_graph(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) + return outputs + + def _create_cuda_graph(self, *inputs, **kwargs): + # warmup to create the workspace and cublas handle + cuda_stream = torch.cuda.Stream() + cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(cuda_stream): + for i in range(3): + ret = self._forward(*inputs, **kwargs) + torch.cuda.current_stream().wait_stream(cuda_stream) + + # create cuda_graph and assign static_inputs and static_outputs + self._cuda_graphs = torch.cuda.CUDAGraph() + self.static_inputs = inputs + self.static_kwargs = kwargs + + with torch.cuda.graph(self._cuda_graphs): + self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) + + self.cuda_graph_created = True + + def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True): + if self._trace_enabled: + if self._traced_unet is None: + print("Unet: start tracing with Nvfuser") + # force return tuple instead of dict + self._traced_unet = torch.jit.trace( + lambda _sample, + _timestamp, + _encoder_hidden_states: self.unet(_sample, + _timestamp, + _encoder_hidden_states, + return_dict=False), + (sample, + timestamp, + encoder_hidden_states)) + return self.unet(sample, timestamp, encoder_hidden_states) + else: + # convert return type to UNet2DConditionOutput + out_sample, *_ = self._traced_unet(sample, timestamp, encoder_hidden_states) + return 
diffusers.models.unet_2d_condition.UNet2DConditionOutput( + out_sample) + else: + return self.unet(sample, timestamp, encoder_hidden_states, return_dict) diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py index 043d2626d43c..16f59c4fe70b 100644 --- a/deepspeed/moe/utils.py +++ b/deepspeed/moe/utils.py @@ -6,6 +6,7 @@ def has_moe_layers(m): has_moe = False num_experts = 0 + for _, module in m.named_modules(): if isinstance(module, MoE): has_moe = True diff --git a/deepspeed/ops/transformer/__init__.py b/deepspeed/ops/transformer/__init__.py index 28c8de68dd8b..49b543551f4b 100755 --- a/deepspeed/ops/transformer/__init__.py +++ b/deepspeed/ops/transformer/__init__.py @@ -1,3 +1,4 @@ from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .inference.transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig from .inference.moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference +from .inference.attention import DeepSpeedAttention diff --git a/deepspeed/ops/transformer/inference/__init__.py b/deepspeed/ops/transformer/inference/__init__.py index 66f8124015cb..f5b042d7fdfb 100644 --- a/deepspeed/ops/transformer/inference/__init__.py +++ b/deepspeed/ops/transformer/inference/__init__.py @@ -1,2 +1,3 @@ from .transformer_inference import DeepSpeedTransformerInference, DeepSpeedInferenceConfig from .moe_inference import DeepSpeedMoEInferenceConfig, DeepSpeedMoEInference +from .attention import DeepSpeedAttention diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py new file mode 100644 index 000000000000..f8bad34d5bff --- /dev/null +++ b/deepspeed/ops/transformer/inference/attention.py @@ -0,0 +1,296 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' +import math +import torch +from torch.autograd import Function +from ... 
import op_builder +import torch.nn as nn +from packaging import version as pkg_version +from deepspeed.utils.logging import log_dist +# Cuda modules will be imported if needed +inference_cuda_module = None +minus_inf = -10000.0 +triton_flash_attn = None + + +def load_triton_flash_attn(): + global triton_flash_attn + try: + import triton + except ImportError: + raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") + + if pkg_version.parse(triton.__version__) < pkg_version.parse("2.0"): + raise ImportError("Please install triton 2.0+ or `pip install deepspeed[sd]`") + + from .triton_ops import triton_flash_attn + + +class DeepSpeedAttentionFunction(Function): + @staticmethod + def forward(ctx, + input, + context, + input_mask, + config, + attn_qkvw, + attn_qw, + attn_kw, + attn_vw, + attn_qkvb, + num_attention_heads_per_partition, + norm_factor, + hidden_size_per_partition, + attn_ow, + attn_ob, + score_context_func, + linear_func, + triton_flash_attn_kernel): + def _transpose_for_context(x): + x = x.permute(0, 2, 1, 3) + new_x_layer_shape = x.size()[:-2] + \ + (hidden_size_per_partition,) + return x.reshape(*new_x_layer_shape) + + def _transpose_for_scores(x): + attention_head_size = x.shape[-1] // num_attention_heads_per_partition + new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, + attention_head_size) + x = x.reshape(*new_x_shape) + x = x.permute(0, 2, 1, 3) + return x.contiguous() + + def compute_attention(qkv_out, input_mask): + no_masking = input_mask is None + + head_size = (qkv_out.shape[-1] // 3 // num_attention_heads_per_partition) + if no_masking: + input_mask = torch.empty(1) + + context_layer, _, _ = score_context_func( + qkv_out, + ((1 - input_mask).to(qkv_out.dype) * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, + config.rotary_dim, + config.rotate_half, + config.rotate_every_two, + num_attention_heads_per_partition, + (1 / norm_factor if config.scale_attention else 1.0), + config.triangular_masking, + config.local_attention, + config.window_size, + no_masking, + config.layer_id, + DeepSpeedAttention.layer_id, + torch.empty(1)) + return context_layer + + def selfAttention_fp(input, context, input_mask): + if config.fp16 and input.dtype == torch.float32: + input = input.half() + head_size = input.shape[-1] // config.heads + do_flash_attn = (head_size <= 128) + scale = (1 / norm_factor) * (1 / norm_factor) + if context == None: + qkv_out = linear_func(input, + attn_qkvw, + attn_qkvb if attn_qkvb is not None else attn_qkvw, + attn_qkvb is not None, + True, + do_flash_attn, + config.heads, + DeepSpeedAttention.layer_id) + if do_flash_attn: + context_layer = triton_flash_attn_kernel(qkv_out[0], + qkv_out[1], + qkv_out[2], + scale, + input.shape[-2] % 128 == 0) + context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) + else: + context_layer = compute_attention(qkv_out, input_mask) + else: + query = torch.matmul(input, attn_qw) + key = torch.matmul(context, attn_kw) + value = torch.matmul(context, attn_vw) + query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads, do_flash_attn) + if do_flash_attn: + context_layer = triton_flash_attn_kernel(query, + key, + value, + scale, + input.shape[-2] % 128 == 0) + context_layer = _transpose_for_context(context_layer[:,:,:,:head_size]) + else: + attention_scores = (torch.matmul(query, + key.transpose(-1, + -2)) * + scale).softmax(dim=-1) + context_layer = _transpose_for_context( + torch.matmul(attention_scores, + value)) + + output 
= linear_func(context_layer, + attn_ow, + attn_ob, + attn_ob is not None, + True, + False, + config.heads, + DeepSpeedAttention.layer_id) + return output + + output = selfAttention_fp(input, context, input_mask) + + return output + + @staticmethod + def backward(ctx, grad_output, grad_output1, grad_output2, grad_output3): + raise RuntimeError('You are running with DeepSpeed Inference mode. \ + Please switch to Training mode for running backward!') + + +class DeepSpeedAttention(nn.Module): + """Initialize the DeepSpeed Transformer Layer. + Arguments: + layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers, + layer_id will be 0,1,2...23 when each layer object is instantiated + config: An object of DeepSpeedInferenceConfig + """ + layer_id = 0 + + def __init__( + self, + config, + ): + super(DeepSpeedAttention, self).__init__() + + self.config = config + self.config.layer_id = DeepSpeedAttention.layer_id + DeepSpeedAttention.layer_id += 1 + device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + + data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float + data_type_fp = torch.half if config.fp16 else torch.float + global inference_cuda_module + if inference_cuda_module is None: + builder = op_builder.InferenceBuilder() + inference_cuda_module = builder.load() + + if DeepSpeedAttention.layer_id == 1: + log_dist(f"DeepSpeed-Attention config: {self.config.__dict__}", [0]) + + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_kw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_vw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qw = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, + dtype=data_type_fp, + device=device), + requires_grad=False) + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + if triton_flash_attn is None: + load_triton_flash_attn() + self.triton_flash_attn_kernel = triton_flash_attn() + self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size + self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size + self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads + + self.norm_factor = math.sqrt( + math.sqrt(self.config.hidden_size // self.config.heads)) + + self.score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \ + inference_cuda_module.softmax_context_fp16 + self.linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ + inference_cuda_module.linear_layer_fp32 + self.cuda_graph_created = False + self.enable_cuda_graph = False + + def _graph_replay(self, *inputs, **kwargs): + for i in range(len(inputs)): + if 
torch.is_tensor(inputs[i]): + self.static_inputs[i].copy_(inputs[i]) + for k in kwargs: + if torch.is_tensor(kwargs[k]): + self.static_kwargs[k].copy_(kwargs[k]) + self._cuda_graphs.replay() + return self.static_output + + def _create_cuda_graph(self, *inputs, **kwargs): + # warmup to create the workspace and cublas handle + cuda_stream = torch.cuda.Stream() + cuda_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(cuda_stream): + for i in range(3): + ret = self._forward(*inputs, **kwargs) + torch.cuda.current_stream().wait_stream(cuda_stream) + + # create cuda_graph and assign static_inputs and static_outputs + self._cuda_graphs = torch.cuda.CUDAGraph() + self.static_inputs = inputs + self.static_kwargs = kwargs + + with torch.cuda.graph(self._cuda_graphs): + self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) + + self.cuda_graph_created = True + + def forward(self, *inputs, **kwargs): + if self.enable_cuda_graph: + if self.cuda_graph_created: + outputs = self._graph_replay(*inputs, **kwargs) + else: + self._create_cuda_graph(*inputs, **kwargs) + outputs = self._graph_replay(*inputs, **kwargs) + else: + outputs = self._forward(*inputs, **kwargs) + return outputs + + def _forward(self, input, context=None, input_mask=None): + output = DeepSpeedAttentionFunction.apply(input, + context, + input_mask, + self.config, + self.attn_qkvw, + self.attn_qw, + self.attn_kw, + self.attn_vw, + self.attn_qkvb, + self.num_attention_heads_per_partition, + self.norm_factor, + self.hidden_size_per_partition, + self.attn_ow, + self.attn_ob, + self.score_context_func, + self.linear_func, + self.triton_flash_attn_kernel) + + return output diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index cc7ed35a33a5..a7e043676875 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -408,10 +408,13 @@ def selfAttention_fp(): if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 - qkv_out = linear_func(input, attn_qkvw, attn_qkvb, + attn_qkvb is not None, + False, + False, + num_attention_heads_per_partition, DeepSpeedTransformerInference.layer_id) else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py new file mode 100644 index 000000000000..423a2ff1134d --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton_ops.py @@ -0,0 +1,151 @@ +""" +Inspired by original Triton implementation: +https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + TMP, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * 
stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, N_CTX, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +class triton_flash_attn(torch.nn.Module): + def __init__(self, ): + super(triton_flash_attn, self).__init__() + + def forward(self, q, k, v, sm_scale, block_128=True): + BLOCK = 128 if block_128 else 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], + q.shape[2]), + device=q.device, + dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + tmp, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + k.shape[0], + k.shape[1], + k.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + ) + return o diff --git a/requirements/requirements-sd.txt b/requirements/requirements-sd.txt new file mode 100644 index 000000000000..c9026206a737 --- /dev/null +++ b/requirements/requirements-sd.txt @@ -0,0 +1,2 @@ +diffusers +triton==2.0.0.dev20221005
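
Note on the head-size padding that recurs throughout this patch: ds_linear_layer, add_padding and padd_add_transform in pt_binding.cpp, and the workspace sizing in inference_context.h, all round the per-head dimension up to 32, 64 or 128 before Q/K/V reach the triton flash-attention kernel, so the kernel always sees a fixed 32/64/128-wide head (its BLOCK_DMODEL). A minimal Python sketch of that rule; the helper names are illustrative and not part of the patch:

def padded_head_size(head_size: int) -> int:
    # Mirrors the ternary used in pt_binding.cpp and inference_context.h:
    #   head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)
    return 32 if head_size < 32 else (64 if head_size < 64 else 128)

def needs_padding(head_size: int) -> bool:
    # Condition checked in ds_linear_layer() before pad_data() is launched.
    return (head_size % 32 != 0 and head_size < 64) or (head_size % 64 != 0)

assert padded_head_size(40) == 64 and needs_padding(40)      # a 40-wide head is padded out to 64
assert not needs_padding(64) and not needs_padding(128)      # already-aligned heads pass through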
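
Usage sketch (illustrative, not part of the diff): with the engine.py changes, deepspeed.init_inference can be handed a full diffusers Stable Diffusion pipeline rather than a bare nn.Module. generic_injection(), called from _apply_injection_policy, matches the UNet through UNetPolicy, replaces its CrossAttention blocks with DeepSpeedAttention, and wraps the UNet in DSUNet, which builds its own CUDA graph on first forward. The model id, prompt, and keyword arguments below are assumptions rather than something this patch documents:

import torch
import deepspeed
from diffusers import StableDiffusionPipeline

# diffusers and the pinned triton build come from the new requirements-sd.txt,
# referenced as `pip install deepspeed[sd]` in attention.py's error message.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",      # assumed checkpoint id
    torch_dtype=torch.half).to("cuda")

# fp16 is assumed here because the injected attention path loads half-precision tiles
# (accumulating in fp32); the wrapped DSUNet handles CUDA-graph capture itself, so no
# extra engine flags are shown.
pipe = deepspeed.init_inference(pipe, mp_size=1, dtype=torch.half)

image = pipe("an example prompt").images[0]   # placeholder prompt; output follows the diffusers API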