diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu
index 4ddaabda3eb7..3409f7ba7de8 100644
--- a/csrc/transformer/inference/csrc/dequantize.cu
+++ b/csrc/transformer/inference/csrc/dequantize.cu
@@ -108,3 +108,90 @@ template void launch_dequantize<__half>(__half*, unsigned, unsigned, cudaStream_t);
+
+__global__ void dequantize_kernel(float* output,
+                                  const int8_t* input,
+                                  const float* qscale,
+                                  int hidden_dim,
+                                  unsigned merge_hidden,
+                                  int cnt)
+{
+}
+
+__global__ void dequantize_kernel(__half* output,
+                                  const int8_t* input,
+                                  const float* qscale,
+                                  unsigned hidden_dim,
+                                  unsigned merge_hidden,
+                                  int cnt)
+{
+    unsigned bid = blockIdx.x * gridDim.y + blockIdx.y;
+    unsigned tid = threadIdx.x;
+
+    float local_scale = qscale[blockIdx.x];
+
+    const float* input_cast = reinterpret_cast<const float*>(input);
+    float2* output_cast = reinterpret_cast<float2*>(output);
+
+    input_cast += bid * merge_hidden;
+    output_cast += bid * merge_hidden;
+
+    for (int c = 0; c < cnt; c++) {
+        if (tid < merge_hidden) {
+            float q = input_cast[tid];
+            int8_t* q_int8 = (int8_t*)&q;
+
+            float2 q_f;
+            __half* q_h = (__half*)&q_f;
+
+            q_h[0] = __float2half(local_scale * (float)q_int8[0]);
+            q_h[1] = __float2half(local_scale * (float)q_int8[1]);
+            q_h[2] = __float2half(local_scale * (float)q_int8[2]);
+            q_h[3] = __float2half(local_scale * (float)q_int8[3]);
+            // q_h[4] = __float2half(local_scale * (float)q_int8[4]);
+            // q_h[5] = __float2half(local_scale * (float)q_int8[5]);
+            // q_h[6] = __float2half(local_scale * (float)q_int8[6]);
+            // q_h[7] = __float2half(local_scale * (float)q_int8[7]);
+            output_cast[tid] = q_f;
+            tid += blockDim.x;
+        }
+    }
+}
+
+template <typename T>
+void launch_dequantize(T* output,
+                       const int8_t* input,
+                       const float* qscale,
+                       unsigned output_size,
+                       unsigned hidden_dim,
+                       unsigned groups,
+                       cudaStream_t stream)
+{
+    unsigned threads = 1024;
+    hidden_dim /= 4;
+    unsigned hid_cnt = threads / hidden_dim;
+    unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
+    hid_cnt = hid_cnt > 0 ?
hid_cnt : 1; + + unsigned blocks = output_size / hid_cnt / groups; + dim3 block_dims(threads); + dim3 grid_dims(groups, blocks); + + dequantize_kernel<<>>( + output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt); +} + +template void launch_dequantize(float*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + cudaStream_t); +template void launch_dequantize<__half>(__half*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + cudaStream_t); diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index ccd5469826d0..94f5ad3b51c5 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -1,5 +1,6 @@ #include "custom_cuda_layers.h" +namespace cg = cooperative_groups; #define MAX_CAP 4 #define MAX_SEQ 2048 diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index ccda95918255..a4202043bbd8 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -558,15 +558,55 @@ void ds_layernorm_internal(T* workspace, Context::Instance().GetCurrentStream()); } +template +void quantized_gemm(at::Tensor& output, + T* input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int bsz) +{ + auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options()); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + Context::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_T, + CUBLAS_OP_N, + weight.size(0), + bsz, + weight.size(1), + &alpha, + &gemm_beta, + (T*)weight16.data_ptr(), + (T*)input, + (T*)output.data_ptr(), +#ifdef __HIP_PLATFORM_HCC__ + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + template at::Tensor qkv_unfused_cublas(at::Tensor& output, at::Tensor& input, at::Tensor& weight, + at::Tensor& q_scale, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, const float epsilon, - bool add_bias) + bool add_bias, + bool q_int8) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); @@ -574,48 +614,55 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ds_layernorm_internal(workspace, input, gamma, beta, epsilon); // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream()); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; + if (q_int8) { + quantized_gemm(output, workspace, weight, q_scale, q_scale.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - workspace, - (T*)output.data_ptr(), + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + workspace, + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + 
CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } if (add_bias) launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); - return torch::from_blob(workspace, input.sizes(), input.options()); } template std::vector ds_qkv_gemm(at::Tensor& input, at::Tensor& weight, + at::Tensor& q_scale, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, const float epsilon, bool add_bias, - unsigned num_layers) + unsigned num_layers, + bool q_int8) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); + int out_size = q_int8 ? weight.size(0) : weight.size(1); if (!workspace) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); @@ -628,9 +675,9 @@ std::vector ds_qkv_gemm(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); - auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); - auto inp_norm = - qkv_unfused_cublas(output, input, weight, bias, gamma, beta, epsilon, add_bias); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + auto inp_norm = qkv_unfused_cublas( + output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8); return {output, inp_norm}; } @@ -654,20 +701,18 @@ void quantized_gemm(at::Tensor& output, launch_dequantize((T*)weight16.data_ptr(), (int8_t*)weight.data_ptr(), (float*)qscale.data_ptr(), - weight.size(1), weight.size(0), + weight.size(1), groups, merge_count, Context::Instance().GetCurrentStream()); - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - float alpha = (T)1.0; float gemm_beta = (T)0.0; cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_T, CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), + weight.size(0), bsz, input.size(2), &alpha, @@ -796,7 +841,11 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input, } template -at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op) +at::Tensor ds_vector_matmul(at::Tensor& input, + at::Tensor& weight, + bool async_op, + at::Tensor& q_scale, + bool q_int8) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -805,28 +854,33 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op .device(at::kCUDA) .requires_grad(false); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int out_size = q_int8 ? 
weight.size(0) : weight.size(1); int bsz = input_cont.size(0) * input_cont.size(1); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream(async_op)); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input_cont.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)output.data_ptr(), + auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options); + if (q_int8) { + quantized_gemm(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream(async_op)); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } return output; } @@ -862,6 +916,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, const float epsilon, bool preLayerNorm, bool mlp_after_attn, + at::Tensor& q_scale, + bool q_int8, ActivationFuncType act_func_type) { int bsz = input.size(0) * input.size(1); @@ -881,36 +937,40 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, mlp_after_attn, Context::Instance().GetCurrentStream()); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), + if (q_int8) { + quantized_gemm(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)inp_norm.data_ptr(), + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif - + } if (act_func_type == ActivationFuncType::GELU) { launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } else if (act_func_type == ActivationFuncType::ReLU) { launch_bias_relu((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + q_int8 ? 
weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } @@ -929,6 +989,8 @@ std::vector ds_mlp_gemm(at::Tensor& input, const float epsilon, bool preLayerNorm, bool mlp_after_attn, + at::Tensor& q_scale, + bool q_int8, int activation_type) { auto input_cont = input.contiguous(); @@ -938,7 +1000,10 @@ std::vector ds_mlp_gemm(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int out_size = q_int8 ? weight.size(0) : weight.size(1); + auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(), + {input_cont.size(0), input_cont.size(1), out_size}, + options); int bsz = input_cont.size(0) * input_cont.size(1); auto act_func_type = static_cast(activation_type); @@ -953,6 +1018,8 @@ std::vector ds_mlp_gemm(at::Tensor& input, epsilon, preLayerNorm, mlp_after_attn, + q_scale, + q_int8, act_func_type); return {output, res_add}; @@ -984,20 +1051,6 @@ std::vector ds_mlp_gemm_int8(at::Tensor& input, auto inp_norm = at::empty_like(input_cont); auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); - // computing the blocking across K dimension - // launch_residual_layer_norm((T*)inp_norm.data_ptr(), - // (T*)residual_add.data_ptr(), - // (T*)input_cont.data_ptr(), - // (T*)residual.data_ptr(), - // (T*)input_bias.data_ptr(), - // (T*)gamma.data_ptr(), - // (T*)beta.data_ptr(), - // epsilon, - // bsz, - // input_cont.size(2), - // preLayerNorm, - // Context::Instance().GetCurrentStream()); - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index 39b91ef87cf3..32708bfe0b46 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -109,6 +109,14 @@ void launch_dequantize(T* output, cudaStream_t stream); template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + cudaStream_t stream); +template void launch_gptj_residual_add(T* input, T* output, T* attn, diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 1a5ee744d9bd..8d3fa725164c 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -242,7 +242,8 @@ def init_inference(model, moe_type='standard', args=None, enable_cuda_graph=False, - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + base_dir=""): """Initialize the DeepSpeed InferenceEngine. Arguments: @@ -278,7 +279,19 @@ def init_inference(model, of groups used in quantization. A tuple is passed in if we want to mention that there is extra-grouping for the MLP part of a Transformer layer (e.g. (True, 8) shows we quantize the model using 8 groups for all the network except the MLP part that we use 8 extra grouping). - replace_with_kernel_inject: If set we inject kernel as we initialize the inference-engine + replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. Otherwise, + the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection) + return_tuple: Specify whether or not the transformer layers need to return a tuple or a Tensor. It is set to True by default (returning a tuple). 
+ ep_size: The expert-parallelism size which is used for partitioning the experts across the GPUs in the expert-parallel group. + moe: Specify if the type of Transformer is MoE. It is set to False by default. + moe_experts: The global number of experts used in an MoE layer. + moe_type: Specify the type of MoE layer. We have two types of MoE layer: 'Standard' and 'Residual'. It is set to 'Standard' type by default. + args: All the arguments used for launching the inference api that can be useful at the inference-engine for injecting the optimizations. + enable_cuda_graph: use this flag for capturing the CUDA-Graph of the inference ops, so that it can run faster using the graph replay method, + this is set to False by default + save_mp_checkpoint_path: The path for which we want to save the loaded model with a checkpoint. This feature is used for adjusting the + parallelism degree to help alleviate the model loading overhead. It does not save any new checkpoint if no path is passed. + base_dir: This shows the root directory under which all the checkpoint files exists. This can be passed through the json config too. Returns: A deepspeed.InferenceEngine wrapped model. @@ -309,6 +322,7 @@ def init_inference(model, moe_type, args, enable_cuda_graph, - save_mp_checkpoint_path) + save_mp_checkpoint_path, + base_dir) return engine diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 1c22960b1e31..a4b57a05f37b 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -51,7 +51,8 @@ def __init__(self, moe_type='standard', config=None, enable_cuda_graph=False, - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + base_dir=""): """ Args: model: torch.nn.Module @@ -100,6 +101,9 @@ def __init__(self, self.checkpoint_engine = TorchCheckpointEngine() self._init_quantization_setting(quantization_setting) + # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture + self.remove_mask_prepare_for_bloom() + if enable_cuda_graph: assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ "If you want to use cuda graph, please upgrade torch to at least v1.10" @@ -135,7 +139,8 @@ def __init__(self, moe_type, training_mp_size, self.checkpoint if replace_with_kernel_inject else None, - save_mp_checkpoint_path=save_mp_checkpoint_path) + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir) elif replace_method == 'auto': self._apply_injection_policy( return_tuple=return_tuple, @@ -145,7 +150,8 @@ def __init__(self, moe_type=moe_type, training_mp_size=training_mp_size, checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None, - save_mp_checkpoint_path=save_mp_checkpoint_path) + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir) device = torch.cuda.current_device() self.module.to(device) @@ -165,6 +171,11 @@ def _get_model_config_generate(self, config): self.config = getattr(self.module, 'config', None) if config is None else config self.generate = getattr(self.module, 'generate', None) + def remove_mask_prepare_for_bloom(self): + if hasattr(self.module, 'transformer'): + if hasattr(self.module.transformer, '_prepare_attn_mask'): + self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask + def _create_model_parallel_group(self): # Call the init process if InferenceEngine.inference_mp_group is None: @@ -326,36 +337,37 @@ def _apply_injection_policy(self, moe_type='standard', training_mp_size=1, checkpoint_dir=None, - 
save_mp_checkpoint_path=False): + save_mp_checkpoint_path=False, + base_dir=""): checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - replace_transformer_layer( - client_module, - self.module, - triangular_masking=self.triangular_masking, - policy=injection_policy, - mp_size=self.mp_world_size, - mp_group=self.mp_group, - ep_group=self.ep_group, - expert_mp_group=self.expert_mp_group, - config=self.config, - fp16=(self.dtype == torch.half), - training=False, - return_tuple=return_tuple, - quantize=(self.dtype == torch.int8), - quantize_settings=(self.quantization_scales, - self.quantize_merge_count, - self.mlp_extra_grouping, - self.quantize_groups), - replace_with_kernel_inject=replace_with_kernel_inject, - moe=moe, - moe_experts=moe_experts, - moe_type=moe_type, - training_mp_size=training_mp_size, - checkpoint_dict=checkpoint, - save_mp_checkpoint_path=save_mp_checkpoint_path, - ) + replace_transformer_layer(client_module, + self.module, + triangular_masking=self.triangular_masking, + policy=injection_policy, + mp_size=self.mp_world_size, + mp_group=self.mp_group, + ep_group=self.ep_group, + expert_mp_group=self.expert_mp_group, + config=self.config, + fp16=(self.dtype == torch.half) + or (self.dtype == torch.int8), + training=False, + return_tuple=return_tuple, + quantize=(self.dtype == torch.int8), + quantize_settings=(self.quantization_scales, + self.quantize_merge_count, + self.mlp_extra_grouping, + self.quantize_groups), + replace_with_kernel_inject=replace_with_kernel_inject, + moe=moe, + moe_experts=moe_experts, + moe_type=moe_type, + training_mp_size=training_mp_size, + checkpoint_dict=checkpoint, + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -450,7 +462,7 @@ def _choose_module_key(self, sd): return 'model' def _convert_to_dtype(self): - if self.dtype is torch.int8 and self.quantization_scales is None: + if False: #self.dtype is torch.int8 and self.quantization_scales is None: quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping) model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict, diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index e0f44675dfd7..560f1bc83bc7 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -3,9 +3,15 @@ from ..runtime.zero import GatheredParameters from .layers import LinearLayer, Normalize, EmbeddingLayer import torch +import gc -def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0): +def load_model_with_checkpoint(r_module, + sd, + mp_replace, + ckpt_type, + weight_quantizer=None, + rank=0): error_msgs = [] def transpose(data): @@ -15,7 +21,7 @@ def transpose(data): return data.reshape(data.shape[-1], data.shape[-2]) def load(module, prefix): - args = (sd, prefix, {}, True, [], [], error_msgs) + args = (sd[0], prefix, {}, True, [], [], error_msgs) if len(list(module.parameters())) > 0 and list( module.parameters())[0].numel() == 0: @@ -25,81 +31,142 @@ def load(module, prefix): else: if hasattr(module, 'weight'): module.weight = mp_replace.copy(module.weight.data, - sd[prefix + 'weight']) - if prefix + 'bias' in sd.keys(): - module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias']) + sd[0][prefix + 'weight']) + if prefix + 'bias' in 
sd[0].keys(): + module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias']) + args = None + gc.collect() def load_transformer_layer(module, prefix): if ckpt_type == "tp": def load_parameters(module, prefix): for n, p in module.named_parameters(): - if len(n.split('.')) == 1: - src_shape = sd[prefix + n].shape + if prefix + n in sd[0] and len(n.split('.')) == 1: + if type(sd[0][prefix + n]) is list: + tmp_data, scale = sd[0][prefix + n] + tmp_data = tmp_data + scale = scale.to(torch.cuda.current_device()) + else: + tmp_data = sd[0][prefix + n].to(torch.cuda.current_device()) + scale = None + src_shape = tmp_data.shape dst_shape = p.shape - + inner_dim = 1 if tmp_data.dtype == torch.int8 else 0 + outer_dim = 0 if tmp_data.dtype == torch.int8 else 1 if (len(src_shape) == 2 and len(dst_shape) == 2): - if src_shape[0] == dst_shape[0] and src_shape[ - 1] == dst_shape[1]: - p.data.copy_(sd[prefix + n]) + if (src_shape[inner_dim] == dst_shape[0] + and src_shape[outer_dim] == dst_shape[1]): + if tmp_data.dtype != torch.int8: + p = weight_quantizer.quantize( + transpose(tmp_data) if weight_quantizer. + q_int8 else tmp_data) + else: + p = torch.nn.parameter.Parameter(tmp_data, + requires_grad=False) + p.scale = scale + setattr(module, n, p) else: - if src_shape[0] != dst_shape[0]: - weight_split = torch.split( - sd[prefix + n], - dst_shape[0], - dim=0)[rank].to( - torch.cuda.current_device()).contiguous() + dim = inner_dim if src_shape[inner_dim] != dst_shape[ + 0] else outer_dim + dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1 + if src_shape[dim] > dst_shape[dim1]: + weight_partition = torch.split( + tmp_data, + dst_shape[dim1], + dim=dim)[rank].to(torch.cuda.current_device()) + assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \ + '''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint!\ + Please use the FP16 checkpoint to generate INT8 checkpoint with the sharding parameters!''' + scale = scale.view( + -1)[weight_quantizer.num_groups * + (rank + 1):].reshape( + weight_quantizer.num_groups, + -1).contiguous() else: - weight_split = torch.split( - sd[prefix + n], - dst_shape[1], - dim=1)[rank].to( - torch.cuda.current_device()).contiguous() - p.data.copy_(weight_split.contiguous()) + assert tmp_data.dtype != torch.int8, \ + '''Merging of the checkpoints are not supported when using INT8 checkpoint! 
\ + Please use a as many GPUs as TP-size for the checkpoint''' + all_data = [ + sd[j][prefix + + n] if type(sd[j][prefix + n]) is list else + sd[j][prefix + n].to(torch.cuda.current_device()) + for j in range(len(sd)) + ] + weight_partition = torch.cat([ + ad[0].to(torch.cuda.current_device()) + if type(ad) is list else ad for ad in all_data + ], + dim=dim) + if tmp_data.dtype == torch.int8: + scale = torch.cat([ + ad[1].to(torch.cuda.current_device()) + for ad in all_data + ], + dim=dim) + + if tmp_data.dtype != torch.int8: + weight_partition = weight_quantizer.quantize( + transpose(weight_partition), \ + parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(weight_partition) + else: + weight_partition = torch.nn.parameter.Parameter( + weight_partition, + requires_grad=False) + weight_partition.scale = scale + setattr(module, n, weight_partition) else: if src_shape[0] == dst_shape[0]: - p.data.copy_(sd[prefix + n]) + p.data.copy_(tmp_data) else: - bias_split = torch.split( - sd[prefix + n], - dst_shape[-1])[rank].to( - torch.cuda.current_device()).contiguous() - p.data.copy_(bias_split) + if src_shape[0] > dst_shape[0]: + bias_split = torch.split( + tmp_data, + dst_shape[-1])[rank].to( + torch.cuda.current_device()).contiguous() + p.data.copy_(bias_split) + else: + p.data.copy_( + torch.cat( + [sd[j][prefix + n] for j in range(len(sd))], + dim=0).to(torch.cuda.current_device()). + contiguous()) load_parameters(module, prefix) for n, child in module.named_children(): load_parameters(child, prefix + n + '.') else: - module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight']) - module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias']) - module.attention.attn_qkvw = mp_replace.copy( - module.attention.attn_qkvw.data, - transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight'])) + module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight']) + module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias']) + module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw, + weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']))) module.attention.attn_qkvb = mp_replace.copy( module.attention.attn_qkvb.data, - sd[prefix + 'self_attention.query_key_value.' + 'bias']) - module.attention.attn_ow = mp_replace.copy( - module.attention.attn_ow.data, - transpose(sd[prefix + 'self_attention.dense.' + 'weight'])) + sd[0][prefix + 'self_attention.query_key_value.' + 'bias']) + module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow, + weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight']))) module.attention.attn_ob = mp_replace.copy( module.attention.attn_ob.data, - sd[prefix + 'self_attention.dense.' + 'bias']) - module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' + - 'weight']) - module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + - 'bias']) - module.mlp.inter_w = mp_replace.copy( - module.mlp.inter_w.data, - transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight'])) + sd[0][prefix + 'self_attention.dense.' + 'bias']) + module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' 
+ + 'weight']) + module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + + 'bias']) + module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w, + weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']))) module.mlp.inter_b = mp_replace.copy( module.mlp.inter_b.data, - sd[prefix + 'mlp.dense_h_to_4h.' + 'bias']) - module.mlp.output_w = mp_replace.copy( - module.mlp.output_w.data, - transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight'])) + sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias']) + module.mlp.output_w = mp_replace.copy(module.mlp.output_w, + weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']))) module.mlp.output_b = mp_replace.copy( module.mlp.output_b.data, - sd[prefix + 'mlp.dense_4h_to_h.' + 'bias']) + sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias']) layer_policies = { nn.Linear: load, @@ -117,7 +184,7 @@ def load_module_recursive(module, prefix='', level=0): for name, child in module.named_children(): if child.__class__ in layer_policies: checking_key = prefix + name + '.' - if not any(checking_key in item for item in sd.keys()): + if not any(checking_key in item for item in sd[0].keys()): if hasattr(child, 'weight') and \ (hasattr(child.weight, 'ds_id') and \ child.weight.ds_id in all_ds_ids): @@ -168,6 +235,7 @@ def load_module_recursive(module, prefix='', level=0): embedding_weight = p assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist" r_module.lm_head.weight = embedding_weight - - del sd + for sd_ in sd: + del sd_ sd = None + gc.collect() diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index ae2bc0b9fa36..ccda3c8132b1 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -5,7 +5,7 @@ import deepspeed.ops.transformer as transformer_inference from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy from .replace_policy import replace_policies -from ..runtime.weight_quantizer import WeightQuantization +#from ..runtime.weight_quantizer import WeightQuantization from deepspeed import comm as dist from torch import nn @@ -115,8 +115,10 @@ def copy(self, dst, src): dst_shape[-1])[self.gpu_index].to( torch.cuda.current_device()).contiguous() dst.data.copy_(bias_split) - - return torch.nn.parameter.Parameter(dst, requires_grad=False) + dst = torch.nn.parameter.Parameter(dst, requires_grad=False) + if hasattr(src, 'scale'): + dst.scale = src.scale + return dst def get_transformer_name(replaced_module): @@ -134,6 +136,57 @@ def get_transformer_name(replaced_module): return transformer_name +class GroupQuantizer: + def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8): + self.num_groups = num_groups + self.group_size = group_size + self.num_bits = num_bits + self.q_int8 = q_int8 + + def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): + if not self.q_int8 or not qkv: + inputs = torch.nn.Parameter(inputs, requires_grad=False) + inputs.scale = torch.empty(1) + return inputs + q_range = 2**self.num_bits + inputs = inputs.to(torch.cuda.current_device()) + input_flat = inputs.reshape(self.num_groups, -1).contiguous() + input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float() + input_max = 
torch.max(input_flat, dim=1, keepdim=True)[0].float() + scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range) + input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1) + inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous() + out = torch.nn.Parameter(inputs_q, requires_grad=False) + #print(inputs.shape) + inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim) + input_flat = [ + inputs_split[i].reshape(self.num_groups, + -1).contiguous() for i in range(2) + ] + input_min = [ + torch.min(input_flat[i], + dim=1, + keepdim=True)[0].float() for i in range(2) + ] + input_max = [ + torch.max(input_flat[i], + dim=1, + keepdim=True)[0].float() for i in range(2) + ] + scale1 = [ + (torch.max(input_min[i].abs(), + input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0) + for i in range(2) + ] + + out.scale = torch.cat([scale.squeeze().unsqueeze(0), + scale1[0], + scale1[1]], + dim=0).reshape(self.num_groups, + -1).contiguous() + return out + + def replace_transformer_layer(orig_layer_impl, model, policy=None, @@ -161,7 +214,8 @@ def replace_transformer_layer(orig_layer_impl, moe_experts=1, moe_type='standard', checkpoint_dict=None, - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + base_dir=""): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -225,7 +279,7 @@ def replace_with_policy(child, _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type) attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm() - if quantize: + if False: if policy_cls is not HFBertLayerPolicy: qkvw = qkvw.to(torch.int8) dense_w = dense_w.to(torch.int8) @@ -257,6 +311,7 @@ def replace_with_policy(child, #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group) + quantizer = GroupQuantizer(q_int8=quantize) if inference: if moe: ep_world_size = dist.get_world_size() @@ -329,21 +384,21 @@ def replace_with_policy(child, new_module = transformer_inference.DeepSpeedTransformerInference( transformer_config, mp_group=mp_group, - quantize_scales=quantization_scales[layer_id], + #quantize_scales=quantization_scales[layer_id], quantize_groups=quantize_groups, merge_count=merge_count, mlp_extra_grouping=mlp_extra_grouping, qkv_merging=(policy_cls is HFBertLayerPolicy)) - if quantize and qkvw.dtype != torch.int8: - quantize_bits = 8 - quantizer = WeightQuantization() - if policy_cls is HFBertLayerPolicy: - data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3) - else: - data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups) - qkvw.data.copy_(data_quantized) - qkvw.data = qkvw.data.to(torch.int8) + #if quantize and qkvw.dtype != torch.int8: + # quantize_bits = 8 + # quantizer = WeightQuantization() + # if policy_cls is HFBertLayerPolicy: + # data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3) + # else: + # data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups) + # qkvw.data.copy_(data_quantized) + # qkvw.data = qkvw.data.to(torch.int8) else: if moe: @@ -478,18 +533,17 @@ def _transpose(x): attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) else: - if bigscience_bloom: - attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw) - 
attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb) - else: - attn_block.attn_qkvw = mp_replace.qkv_copy( - attn_block.attn_qkvw, - qkvw) - attn_block.attn_qkvb = mp_replace.qkv_copy( - attn_block.attn_qkvb, - qkvb) - - attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) + attn_block.attn_qkvw = quantizer.quantize( + mp_replace.copy(attn_block.attn_qkvw, qkvw) if bigscience_bloom else \ + mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)) + attn_block.attn_qkvb = \ + mp_replace.copy(attn_block.attn_qkvb, qkvb) if bigscience_bloom else \ + mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb) + + attn_block.attn_ow = quantizer.quantize( + mp_replace.copy(attn_block.attn_ow, + dense_w)) + attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) if moe: @@ -545,9 +599,13 @@ def _transpose(x): mpl_block.output_b, _4hh_b) else: - mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w) + mpl_block.inter_w = quantizer.quantize( + mp_replace.copy(mpl_block.inter_w, + _h4h_w)) mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b) - mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w) + mpl_block.output_w = quantizer.quantize( + mp_replace.copy(mpl_block.output_w, + _4hh_w)) mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b) if attn_nw is None: @@ -782,50 +840,92 @@ def replace_fn(child, _policy, layer_id=0): replace_fn=replace_fn, _replace_policy=policy) + quantizer = GroupQuantizer(q_int8=quantize) world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 if checkpoint_dict is not None: start_time = time.time() checkpoint = checkpoint_dict['checkpoints'] + ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint ckpt_type = checkpoint_dict.get('parallelization', 'pp') - ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size) - base_dir = checkpoint_dict.get('base_dir', '') + ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list)) + ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size) + base_dir1 = checkpoint_dict.get('base_dir', base_dir) - if ckpt_type == 'pp': + if ckpt_type == 'pp' and type(checkpoint) is list: pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") + for i in range(len(checkpoint)): - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar.update(1) - sd = torch.load(checkpoint[i], map_location='cpu') - load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) + + sd = [ + torch.load(os.path.join(base_dir1, + checkpoint[i]), + map_location='cpu') + ] + load_model_with_checkpoint( + replaced_module, + sd, + mp_replace, + ckpt_type, + quantizer, + ) else: - num_checkpoints = len(checkpoint) // ckpt_mp_size - assert world_size >= ckpt_mp_size,\ - "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!" 
- checkpoint_stride = world_size // ckpt_mp_size - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar = tqdm.tqdm(total=num_checkpoints, - desc=f"Loading {num_checkpoints} checkpoint shards") + import gc + num_checkpoints = len(ckpt_list) // ckpt_mp_size + tp_split_size = (world_size / ckpt_mp_size) + sd_offset = int(rank / tp_split_size) + sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset + pbar = tqdm.tqdm(total=num_checkpoints, + desc=f"Loading {num_checkpoints} checkpoint shards") for i in range(num_checkpoints): - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar.update(1) - - ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride) - ckpt_file = os.path.join( - base_dir, - checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index] - sd = torch.load(ckpt_file, map_location='cpu') + pbar.update(1) + ckpt_index = i * ckpt_mp_size + sd_offset + ckpt_files = [ + os.path.join(base_dir1, + ckpt_list[ckpt_index + + j]) if base_dir1 else ckpt_list[ckpt_index + + j] + for j in range(sd_count) + ] + sds = [ + torch.load(ckpt_file, + map_location='cpu') for ckpt_file in ckpt_files + ] load_model_with_checkpoint(replaced_module, - sd, + sds, mp_replace, ckpt_type, - rank % (world_size // ckpt_mp_size)) + quantizer, + int(rank % tp_split_size)) + sds = [None for _ in sds] + gc.collect() + + if "non_tp" in checkpoint: + pbar = tqdm.tqdm( + total=len(checkpoint["non_tp"]), + desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards") + + for i in range(len(checkpoint["non_tp"])): + pbar.update(1) + ckpt_file = os.path.join(base_dir1, + checkpoint["non_tp"][i] + ) if base_dir1 else checkpoint["non_tp"][i] + sds = [torch.load(ckpt_file, map_location='cpu')] + load_model_with_checkpoint(replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % tp_split_size)) + sds = [None for _ in sds] + gc.collect() print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: from collections import OrderedDict import json + num_partitions = 8 if checkpoint_dict is None: ckpt_name = "ds_model" @@ -840,8 +940,8 @@ def replace_fn(child, _policy, layer_id=0): if dist.is_initialized(): dist.barrier() transformer_name = get_transformer_name(replaced_module) - non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt' - ckpt_files = [non_tp_ckpt_name] * world_size + non_tp_ckpt_name = f'non-tp.pt' + ckpt_files = [non_tp_ckpt_name] os.makedirs(save_mp_checkpoint_path, exist_ok=True) if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") @@ -853,25 +953,47 @@ def replace_fn(child, _policy, layer_id=0): if transformer_name not in k }), f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}') - ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] config = json.dumps({ - 'type': ckpt_name, - 'base_dir': f'{save_mp_checkpoint_path}', - 'checkpoints': ckpt_files, - 'version': 1.0, - 'parallelization': 'tp', - 'mp_size': world_size + 'type': + ckpt_name, + 'base_dir': + f'{save_mp_checkpoint_path}', + 'checkpoints': { + "non_tp": + ckpt_files, + "tp": [ + f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) + for r in range(world_size) + ] + }, + 'version': + 1.0, + 'parallelization': + 'tp', + 'tp_size': + world_size, + 'dtype': + 'int8' if quantize else ('float16' if fp16 else 'float32') }) - with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", - "w") as cfg: + with 
open(f"{save_mp_checkpoint_path}/ds-inference_config.json", "w") as cfg: cfg.write(config) - torch.save( - OrderedDict({ - k: v - for k, - v in dict(replaced_module.state_dict()).items() if transformer_name in k - }), - f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt') + + rep_sd = replaced_module.state_dict() + for n, p in replaced_module.named_parameters(): + if hasattr(p, 'scale'): + rep_sd[n] = [p, p.scale] + keys = list(rep_sd.keys()) + partition_size = (len(keys) // num_partitions + 1) + for m in range(num_partitions): + torch.save( + OrderedDict({ + k: [rep_sd[k], + rep_sd[k].scale] if hasattr(rep_sd[k], + 'scale') else rep_sd[k] + for k in keys[m * partition_size:(m + 1) * partition_size] + if transformer_name in k + }), + f'{save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt') return replaced_module diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index f03cc1248578..fa28a34f04a2 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -206,16 +206,6 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), - dim=1) - - presents = (key_layer, value_layer) - # [batch_size, head_dim, q_length, k_length] output_size = (query_layer.size(0), query_layer.size(2), @@ -223,24 +213,37 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): key_layer.size(1)) # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] query_layer = query_layer.transpose(1, - 0).reshape( - output_size[2], + 2).reshape( output_size[0] * output_size[1], + output_size[2], -1) # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] key_layer = key_layer.transpose(1, - 0).reshape(output_size[3], - output_size[0] * output_size[1], - -1) + 2).reshape(output_size[0] * output_size[1], + output_size[3], + -1).transpose(-1, + -2) + value_layer = value_layer.transpose(1, + 2).reshape( + output_size[0] * output_size[1], + output_size[3], + -1) + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1) + value_layer = torch.cat((past_value.type_as(value_layer), + value_layer), + dim=-2) + presents = (key_layer, value_layer) # Raw attention scores. 
[batch_size * num_heads, q_length, k_length] - matmul_result = torch.matmul(query_layer.transpose(1, - 0), - key_layer.transpose(1, - 0).transpose(1, - 2)) + matmul_result = torch.matmul(query_layer, key_layer) # change view to [batch_size, num_heads, q_length, k_length] - attention_scores = matmul_result.view(*output_size) + attention_scores = matmul_result.view(output_size[0], + output_size[1], + output_size[2], + -1) offset = dist.get_rank( ) * num_attention_heads_per_partition if dist.is_initialized() else 0 @@ -261,12 +264,7 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): attention_probs_reshaped = attention_probs.view(*matmul_result.shape) # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = torch.bmm( - attention_probs_reshaped, - value_layer.transpose(1, - 2).reshape(-1, - value_layer.size(1), - value_layer.size(3))) + context_layer = torch.bmm(attention_probs_reshaped, value_layer) # change view [batch_size, num_heads, q_length, head_dim] context_layer = context_layer.view( @@ -418,15 +416,21 @@ def selfAttention_fp(): qkv_out = qkv_func( input, attn_qkvw, + attn_qkvw.scale, (attn_qkvb if attn_qkvb is not None else norm_b), norm_w, norm_b, config.epsilon, (attn_qkvb is not None), 1 if config.bigscience_bloom else - DeepSpeedTransformerInference.layer_id) + DeepSpeedTransformerInference.layer_id, + config.q_int8) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) - output = vector_matmul_func(context_layer, attn_ow, False) + output = vector_matmul_func(context_layer, + attn_ow, + False, + attn_ow.scale, + config.q_int8) return output, key_layer, value_layer, context_layer, qkv_out[-1] @@ -458,7 +462,7 @@ def selfAttention_int8(): (merge_count)) return output, key_layer, value_layer, context_layer - if config.q_int8: + if False: #config.q_int8: output, key_layer, value_layer, context_layer = selfAttention_int8() else: output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp() @@ -486,30 +490,34 @@ def __init__(self, qkv_merging=False): super(DeepSpeedSelfAttention, self).__init__() self.config = config - data_type = torch.half if config.fp16 else torch.float + data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float + data_type_fp = torch.half if config.fp16 else torch.float self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.attn_qkvw = nn.Parameter( - torch.empty(self.config.hidden_size, - (self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type, - device=device)) - self.attn_qkvb = nn.Parameter( - torch.empty((self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type, - device=device)) - - self.attn_ow = nn.Parameter( - torch.empty(self.config.hidden_size // self.config.mp_size, - self.config.hidden_size, - dtype=data_type, - device=device)) - - self.attn_ob = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) + self.attn_qkvw = nn.Parameter(torch.empty( + self.config.hidden_size, + (self.config.hidden_size // self.config.mp_size) * 3, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty( + (self.config.hidden_size // self.config.mp_size) * 3, + dtype=data_type_fp, + device=device), + requires_grad=False) + + self.attn_ow = 
nn.Parameter(torch.empty(self.config.hidden_size // + self.config.mp_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size @@ -595,36 +603,16 @@ def forward(ctx, bias_residual_func, activation_func_type=ActivationFuncType.GELU): - if config.q_int8: - (intermediate, - residual_add) = inference_cuda_module.mlp_gemm_int8( - input, - residual, - bias, - inter_w, - inter_b, - attn_nw, - attn_nb, - config.epsilon, - q_scales[2], - (q_groups * (2**merge_count)), - config.pre_layer_norm) - output = inference_cuda_module.vector_matmul_int8(intermediate, - output_w, - q_scales[3], - q_groups, - (merge_count)) + if attn_nw is None: + output = fused_gemm_gelu(residual_norm, + inter_w, + inter_b, + output_w, + config.epsilon, + config.pre_layer_norm, + False) else: - if attn_nw is None: - output = fused_gemm_gelu(residual_norm, - inter_w, - inter_b, - output_w, - config.epsilon, - config.pre_layer_norm, - False) - else: - intermediate, residual_add = mlp_gemm_func(input, + intermediate, residual_add = mlp_gemm_func(input, residual, bias, inter_w, @@ -634,9 +622,14 @@ def forward(ctx, config.epsilon, config.pre_layer_norm, config.mlp_after_attn, + inter_w.scale, + config.q_int8, config.mlp_act_func_type) - output = vector_matmul_func(intermediate, output_w, False) - + output = vector_matmul_func(intermediate, + output_w, + False, + output_w.scale, + config.q_int8) inference_cuda_module.residual_add( output, residual if config.pre_layer_norm else residual_add, @@ -668,34 +661,38 @@ def __init__(self, super(DeepSpeedMLP, self).__init__() self.config = config - data_type = torch.half if config.fp16 else torch.float + data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float + data_type_fp = torch.half if config.fp16 else torch.float device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.attn_nw = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) - self.attn_nb = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) - self.inter_w = nn.Parameter( - torch.empty(self.config.hidden_size, - self.config.intermediate_size // self.config.mp_size, - dtype=data_type, - device=device)) - self.inter_b = nn.Parameter( - torch.empty(self.config.intermediate_size // self.config.mp_size, - dtype=data_type, - device=device)) - self.output_w = nn.Parameter( - torch.empty((self.config.intermediate_size // self.config.mp_size), - self.config.hidden_size, - dtype=data_type, - device=device)) - self.output_b = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) + self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.intermediate_size // + self.config.mp_size, + dtype=data_type, + device=device), + requires_grad=False) + self.inter_b = nn.Parameter(torch.empty(self.config.intermediate_size // + self.config.mp_size, + dtype=data_type_fp, + device=device), 
+ requires_grad=False) + self.output_w = nn.Parameter(torch.empty( + (self.config.intermediate_size // self.config.mp_size), + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) # used for quantization self.q_scales = q_scales @@ -790,14 +787,14 @@ def __init__(self, mlp_extra_grouping) device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.norm_w = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) - self.norm_b = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) + self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) self.layer_past = None def forward( @@ -826,7 +823,6 @@ def forward( # We set the prev key/value to None when there is a prompt if input.shape[1] > 1: self.layer_past = None - layer_past = layer_past if layer_past is not None else self.layer_past head_mask = layer_head_mask if layer_head_mask is not None else head_mask
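
Note on the quantization scheme: the GroupQuantizer added in replace_module.py performs symmetric per-group INT8 quantization. The weight is viewed as num_groups rows, each group gets a scale of max(|min|, |max|) * 2 / 2^bits, and values are rounded and clamped to [-128, 127]. The sketch below mirrors that math in plain PyTorch; the helper names are illustrative rather than DeepSpeed APIs, and the extra per-partition scales the quantizer also stores for tensor-parallel QKV splitting are omitted.

```python
import torch

def group_quantize(weight: torch.Tensor, num_groups: int = 32, num_bits: int = 8):
    """Symmetric per-group quantization, mirroring GroupQuantizer.quantize."""
    q_range = 2**num_bits                                          # 256 for int8
    flat = weight.reshape(num_groups, -1).float()
    g_min = flat.min(dim=1, keepdim=True)[0]
    g_max = flat.max(dim=1, keepdim=True)[0]
    scale = torch.max(g_min.abs(), g_max.abs()) * 2.0 / q_range    # one scale per group
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
    return q.reshape(weight.shape).to(torch.int8), scale

def group_dequantize(q: torch.Tensor, scale: torch.Tensor, num_groups: int = 32):
    """Inverse mapping used by dequantize_kernel: w ~ q * scale[group]."""
    return (q.reshape(num_groups, -1).float() * scale).reshape(q.shape)

# Round-trip on a toy weight (shapes are arbitrary for illustration)
w = torch.randn(1024, 1024)
q, scale = group_quantize(w)
print((w - group_dequantize(q, scale)).abs().max())  # small quantization error
```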
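
The CUDA-side consumers of those scales are the new dequantize_kernel/launch_dequantize in dequantize.cu and quantized_gemm in pt_binding.cpp: at inference time the INT8 weight is dequantized back to FP16 with one scale per group and then fed to a regular cuBLAS GEMM against the transposed weight (CUBLAS_OP_T, output dimension taken from weight.size(0)). A rough PyTorch reference of that data flow, under the assumption that the quantized weight is stored as [out_features, in_features]:

```python
import torch

def quantized_matmul_reference(x: torch.Tensor,
                               q_weight: torch.Tensor,
                               scales: torch.Tensor,
                               num_groups: int = 32) -> torch.Tensor:
    """Dequantize-then-GEMM reference (the patch does this with a CUDA kernel + cuBLAS).

    x:        [tokens, in_features]             activations
    q_weight: [out_features, in_features] int8  quantized weight
    scales:   [num_groups, 1] float             per-group scales
    """
    # w ~ q * scale[group], matching q_h[i] = __float2half(local_scale * q_int8[i])
    w = (q_weight.reshape(num_groups, -1).float() * scales).reshape(q_weight.shape)
    # GEMM against the transposed weight, matching the CUBLAS_OP_T call in quantized_gemm
    return x.float() @ w.t()

# Hypothetical sizes: 8 tokens, hidden 1024, fused QKV output 3072
x = torch.randn(8, 1024)
qw = torch.randint(-128, 128, (3072, 1024), dtype=torch.int8)
sc = torch.rand(32, 1) * 0.01
print(quantized_matmul_reference(x, qw, sc).shape)  # torch.Size([8, 3072])
```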
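
On the Python side, init_inference gains a base_dir argument, and save_mp_checkpoint_path now writes a resharded (optionally INT8) copy of the model as ds-inference_config.json plus tp_XX_YY.pt / non-tp.pt shards. A hedged usage sketch follows; the model id and paths are placeholders, and mp_size/checkpoint come from the pre-existing init_inference API rather than this diff.

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # placeholder model

SAVE_RESHARDED = True  # flip to False once the resharded checkpoint has been written

if SAVE_RESHARDED:
    # First run: inject kernels, quantize to INT8, and save a tensor-parallel-resharded
    # checkpoint so later runs can skip the slow load/reshard step.
    engine = deepspeed.init_inference(model,
                                      mp_size=2,
                                      dtype=torch.int8,
                                      replace_with_kernel_inject=True,
                                      save_mp_checkpoint_path="/tmp/bloom-int8-tp2")
else:
    # Subsequent runs: point `checkpoint` at the generated config and use `base_dir`
    # to resolve the relative shard names ("tp_00_00.pt", ..., "non-tp.pt") inside it.
    engine = deepspeed.init_inference(model,
                                      mp_size=2,
                                      dtype=torch.int8,
                                      replace_with_kernel_inject=True,
                                      checkpoint="/tmp/bloom-int8-tp2/ds-inference_config.json",
                                      base_dir="/tmp/bloom-int8-tp2")

# output = engine.module.generate(input_ids, max_new_tokens=20)
```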