From cf2fe01107dec7652aa38e9c0d7384cf2c7c2205 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 9 Aug 2022 04:02:24 +0500 Subject: [PATCH 01/16] Fix the layer-past for GPT based models --- deepspeed/ops/transformer/inference/transformer_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index df65fb317e9b..d38cf8c3d395 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -812,6 +812,10 @@ def forward(self, output_attentions=False): get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask + + # We set the prev key/value to None when there is a prompt + if input.shape[1] > 1: + self.layer_past = None layer_past = layer_past if layer_past is not None else self.layer_past attn_mask = None From c2cf304c86e613004e8212d9dd4be486d40b94de Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 13 Aug 2022 06:10:51 +0500 Subject: [PATCH 02/16] add the Int8 support for ds-inference using ZeroQuant technology --- csrc/transformer/inference/csrc/dequantize.cu | 87 +++++++ csrc/transformer/inference/csrc/gelu.cu | 176 ++++++++++++++ .../transformer/inference/csrc/pt_binding.cpp | 229 +++++++++++------- .../inference/includes/custom_cuda_layers.h | 27 +++ deepspeed/inference/engine.py | 4 +- deepspeed/module_inject/load_checkpoint.py | 126 ++++++---- deepspeed/module_inject/replace_module.py | 147 +++++++---- .../inference/transformer_inference.py | 183 +++++++------- 8 files changed, 709 insertions(+), 270 deletions(-) diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 4ddaabda3eb7..3409f7ba7de8 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -108,3 +108,90 @@ template void launch_dequantize<__half>(__half*, unsigned, unsigned, cudaStream_t); + +__global__ void dequantize_kernel(float* output, + const int8_t* input, + const float* qscale, + int hidden_dim, + unsigned merge_hidden, + int cnt) +{ +} + +__global__ void dequantize_kernel(__half* output, + const int8_t* input, + const float* qscale, + unsigned hidden_dim, + unsigned merge_hidden, + int cnt) +{ + unsigned bid = blockIdx.x * gridDim.y + blockIdx.y; + unsigned tid = threadIdx.x; + + float local_scale = qscale[blockIdx.x]; + + const float* input_cast = reinterpret_cast(input); + float2* output_cast = reinterpret_cast(output); + + input_cast += bid * merge_hidden; + output_cast += bid * merge_hidden; + + for (int c = 0; c < cnt; c++) { + if (tid < merge_hidden) { + float q = input_cast[tid]; + int8_t* q_int8 = (int8_t*)&q; + + float2 q_f; + __half* q_h = (__half*)&q_f; + + q_h[0] = __float2half(local_scale * (float)q_int8[0]); + q_h[1] = __float2half(local_scale * (float)q_int8[1]); + q_h[2] = __float2half(local_scale * (float)q_int8[2]); + q_h[3] = __float2half(local_scale * (float)q_int8[3]); + // q_h[4] = __float2half(local_scale * (float)q_int8[4]); + // q_h[5] = __float2half(local_scale * (float)q_int8[5]); + // q_h[6] = __float2half(local_scale * (float)q_int8[6]); + // q_h[7] = __float2half(local_scale * (float)q_int8[7]); + output_cast[tid] = q_f; + tid += blockDim.x; + } + } +} + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned 
hidden_dim, + unsigned groups, + cudaStream_t stream) +{ + unsigned threads = 1024; + hidden_dim /= 4; + unsigned hid_cnt = threads / hidden_dim; + unsigned thd_cnt = (hidden_dim - 1) / threads + 1; + hid_cnt = hid_cnt > 0 ? hid_cnt : 1; + + unsigned blocks = output_size / hid_cnt / groups; + dim3 block_dims(threads); + dim3 grid_dims(groups, blocks); + + dequantize_kernel<<>>( + output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt); +} + +template void launch_dequantize(float*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + cudaStream_t); +template void launch_dequantize<__half>(__half*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + cudaStream_t); diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index c3d65fa037c2..f0ba2e2c07d9 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -1,5 +1,6 @@ #include "custom_cuda_layers.h" +namespace cg = cooperative_groups; #define MAX_CAP 4 #define MAX_SEQ 2048 @@ -537,3 +538,178 @@ template void launch_moe_res_matmul(__half* residual, int seq_len, int hidden_dim, cudaStream_t stream); + +__device__ void quantize_kernel_glue(float2* data, + unsigned cnt, + int8_t* vals_int, + float* q_scale_d, + int num_bits, + int group_size) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int gid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + int id = threadIdx.x; + + float* vals_int_cast = reinterpret_cast(vals_int); + + __half max = -10000.0; + int bid = blockIdx.x; + unsigned group_index; + for (int i = 0; i < cnt; i++) { + __half* data_h = reinterpret_cast<__half*>(&data[i]); + if (__hgt(__habs(data_h[0]), max)) max = __habs(data_h[0]); + if (__hgt(__habs(data_h[1]), max)) max = __habs(data_h[1]); + if (__hgt(__habs(data_h[2]), max)) max = __habs(data_h[2]); + if (__hgt(__habs(data_h[3]), max)) max = __habs(data_h[3]); + } + +#pragma unroll + for (int i = 1; i < WARP_SIZE; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (__hgt(temp, max)) max = temp; + } + __shared__ __half partialMax[WARP_SIZE]; + + if (lane == 0) partialMax[gid] = max; + + b.sync(); + + max = partialMax[lane]; + + b.sync(); + +#pragma unroll + for (int i = 1; i < warp_num; i <<= 1) { + auto temp = g.shfl_xor(max, i); + if (__hgt(temp, max)) max = temp; + } + max = g.shfl(max, 0); + + float q_scale = (1 << num_bits) / (2 * (float)max); + + group_index = threadIdx.x + bid * group_size; + for (int i = 0; i < cnt; i++) { + float q_data_int; // = (float)(int)(1 << 8 | 1 << 16 | 1 << 24 | 1); + int8_t* q_data_8 = reinterpret_cast(&q_data_int); + __half* data_h = reinterpret_cast<__half*>(&data[i]); + int32_t data_f[4]; + data_f[0] = round((float)data_h[0] * q_scale); + data_f[1] = round((float)data_h[1] * q_scale); + data_f[2] = round((float)data_h[2] * q_scale); + data_f[3] = round((float)data_h[3] * q_scale); + q_data_8[0] = data_f[0] > 127 ? 127 : (data_f[0] < -128 ? -128 : data_f[0]); + q_data_8[1] = data_f[1] > 127 ? 127 : (data_f[1] < -128 ? -128 : data_f[1]); + q_data_8[2] = data_f[2] > 127 ? 127 : (data_f[2] < -128 ? -128 : data_f[2]); + q_data_8[3] = data_f[3] > 127 ? 127 : (data_f[3] < -128 ? 
-128 : data_f[3]); + vals_int_cast[group_index] = q_data_int; + group_index += (blockDim.x); + } + if (threadIdx.x == 0) q_scale_d[blockIdx.x] = 1 / q_scale; +} +__global__ void fused_bias_gelu_int8(int8_t* output, + float* scales, + __half* input, + const __half* bias, + int total_count, + int intermediate_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * intermediate_size; + int id = threadIdx.x; + float2 vals_vec[8]; + unsigned cnt = 0; + while (id < intermediate_size) { + vals_vec[cnt] = input_cast[offset + id]; + float2 bias_vec = bias_cast[id]; + + __half2* vals_half = reinterpret_cast<__half2*>(vals_vec + cnt); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + // input_cast[offset + id] = vals_vec; + id += blockDim.x; + cnt++; + } + quantize_kernel_glue(vals_vec, cnt, output, scales, 8, intermediate_size); +#endif +} +__global__ void quantize_int8(int8_t* output, + float* scales, + __half* input, + int total_count, + int intermediate_size) +{ + float2* input_cast = reinterpret_cast(input); + + int offset = blockIdx.x * intermediate_size; + int id = threadIdx.x; + float2 vals_vec[8]; + unsigned cnt = 0; + while (id < intermediate_size) { + vals_vec[cnt] = input_cast[offset + id]; + + id += blockDim.x; + cnt++; + } + quantize_kernel_glue(vals_vec, cnt, output, scales, 8, intermediate_size); +} + +void launch_bias_gelu_int8(int8_t* output, + float* scales, + __half* input, + const __half* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int total_count = batch_size * (intermediate_size / 4); + int threads = 1024; // intermediate_size / iterations / 4; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); // (batch_size); + + fused_bias_gelu_int8<<>>( + output, scales, input, bias, total_count, intermediate_size / 4); +} + +void launch_me(int8_t* output, + float* scales, + __half* input, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int total_count = batch_size * (intermediate_size / 4); + int threads = 1024; // intermediate_size / iterations / 4; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); // (batch_size); + + quantize_int8<<>>( + output, scales, input, total_count, intermediate_size / 4); +} diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 360beaa10284..3f1d5b935abd 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -537,15 +537,55 @@ void ds_layernorm_internal(T* workspace, Context::Instance().GetCurrentStream()); } +template +void quantized_gemm(at::Tensor& output, + T* input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int bsz) +{ + auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options()); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + 
(float*)qscale.data_ptr(), + weight.size(0), + weight.size(1), + groups, + Context::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_T, + CUBLAS_OP_N, + weight.size(0), + bsz, + weight.size(1), + &alpha, + &gemm_beta, + (T*)weight16.data_ptr(), + (T*)input, + (T*)output.data_ptr(), +#ifdef __HIP_PLATFORM_HCC__ + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + template at::Tensor qkv_unfused_cublas(at::Tensor& output, at::Tensor& input, at::Tensor& weight, + at::Tensor& q_scale, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, const float epsilon, - bool add_bias) + bool add_bias, + bool q_int8) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); @@ -553,48 +593,55 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ds_layernorm_internal(workspace, input, gamma, beta, epsilon); // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream()); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; + if (q_int8) { + quantized_gemm(output, workspace, weight, q_scale, q_scale.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - workspace, - (T*)output.data_ptr(), + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + workspace, + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } if (add_bias) launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz, Context::Instance().GetCurrentStream()); - return torch::from_blob(workspace, input.sizes(), input.options()); } template std::vector ds_qkv_gemm(at::Tensor& input, at::Tensor& weight, + at::Tensor& q_scale, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, const float epsilon, bool add_bias, - unsigned num_layers) + unsigned num_layers, + bool q_int8) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); + int out_size = q_int8 ? 
weight.size(0) : weight.size(1); if (!workspace) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); @@ -607,9 +654,9 @@ std::vector ds_qkv_gemm(at::Tensor& input, .device(at::kCUDA) .requires_grad(false); - auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); - auto inp_norm = - qkv_unfused_cublas(output, input, weight, bias, gamma, beta, epsilon, add_bias); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); + auto inp_norm = qkv_unfused_cublas( + output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8); return {output, inp_norm}; } @@ -633,20 +680,18 @@ void quantized_gemm(at::Tensor& output, launch_dequantize((T*)weight16.data_ptr(), (int8_t*)weight.data_ptr(), (float*)qscale.data_ptr(), - weight.size(1), weight.size(0), + weight.size(1), groups, merge_count, Context::Instance().GetCurrentStream()); - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - float alpha = (T)1.0; float gemm_beta = (T)0.0; cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_T, CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), + weight.size(0), bsz, input.size(2), &alpha, @@ -775,7 +820,11 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input, } template -at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op) +at::Tensor ds_vector_matmul(at::Tensor& input, + at::Tensor& weight, + bool async_op, + at::Tensor& q_scale, + bool q_int8) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -784,28 +833,33 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op .device(at::kCUDA) .requires_grad(false); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int out_size = q_int8 ? 
weight.size(0) : weight.size(1); int bsz = input_cont.size(0) * input_cont.size(1); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream(async_op)); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input_cont.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)output.data_ptr(), + auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options); + if (q_int8) { + quantized_gemm(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream(async_op)); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } return output; } @@ -840,7 +894,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, at::Tensor& beta, const float epsilon, bool preLayerNorm, - bool mlp_after_attn) + bool mlp_after_attn, + at::Tensor& q_scale, + bool q_int8) { int bsz = input.size(0) * input.size(1); auto inp_norm = at::empty_like(input); @@ -859,30 +915,40 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, mlp_after_attn, Context::Instance().GetCurrentStream()); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), + if (q_int8) { + quantized_gemm(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + launch_bias_gelu((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(0), + bsz, + Context::Instance().GetCurrentStream()); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)inp_norm.data_ptr(), + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); + launch_bias_gelu((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + } return inp_norm; } template @@ -895,7 +961,9 @@ std::vector ds_mlp_gemm(at::Tensor& input, at::Tensor& beta, const float epsilon, bool preLayerNorm, - bool mlp_after_attn) + bool mlp_after_attn, + at::Tensor& q_scale, + bool q_int8) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -904,7 +972,10 @@ std::vector ds_mlp_gemm(at::Tensor& input, .device(at::kCUDA) 
.requires_grad(false); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int out_size = q_int8 ? weight.size(0) : weight.size(1); + auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(), + {input_cont.size(0), input_cont.size(1), out_size}, + options); int bsz = input_cont.size(0) * input_cont.size(1); auto res_add = mlp_unfused_cublas(output, @@ -917,7 +988,9 @@ std::vector ds_mlp_gemm(at::Tensor& input, beta, epsilon, preLayerNorm, - mlp_after_attn); + mlp_after_attn, + q_scale, + q_int8); return {output, res_add}; } @@ -948,20 +1021,6 @@ std::vector ds_mlp_gemm_int8(at::Tensor& input, auto inp_norm = at::empty_like(input_cont); auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); - // computing the blocking across K dimension - // launch_residual_layer_norm((T*)inp_norm.data_ptr(), - // (T*)residual_add.data_ptr(), - // (T*)input_cont.data_ptr(), - // (T*)residual.data_ptr(), - // (T*)input_bias.data_ptr(), - // (T*)gamma.data_ptr(), - // (T*)beta.data_ptr(), - // epsilon, - // bsz, - // input_cont.size(2), - // preLayerNorm, - // Context::Instance().GetCurrentStream()); - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index c2bb30126cd6..47a10771d564 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -100,6 +100,14 @@ void launch_dequantize(T* output, cudaStream_t stream); template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + cudaStream_t stream); +template void launch_gptj_residual_add(T* input, T* output, T* attn, @@ -158,3 +166,22 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_every_two, cudaStream_t stream, int trans_count); + +void run_gemm(void* A, + void* B, + void* C, + void* a, + void* aa, + int M, + int N, + int K, + int groups, + int groups1, + cudaStream_t stream); + +void launch_me(int8_t* output, + float* scales, + __half* input, + int intermediate_size, + int batch_size, + cudaStream_t stream); diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index db9efb19dcb1..50baeaade2a6 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -336,7 +336,7 @@ def _apply_injection_policy(self, ep_group=self.ep_group, expert_mp_group=self.expert_mp_group, config=self.config, - fp16=(self.dtype == torch.half), + fp16=(self.dtype == torch.half) or (self.dtype == torch.int8), training=False, return_tuple=return_tuple, quantize=(self.dtype == torch.int8), @@ -446,7 +446,7 @@ def _choose_module_key(self, sd): return 'model' def _convert_to_dtype(self): - if self.dtype is torch.int8 and self.quantization_scales is None: + if False: #self.dtype is torch.int8 and self.quantization_scales is None: quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping) model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict, diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index e0f44675dfd7..c7001c32059c 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -5,7 +5,12 @@ import torch -def 
load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0): +def load_model_with_checkpoint(r_module, + sd, + mp_replace, + ckpt_type, + weight_quantizer=None, + rank=0): error_msgs = [] def transpose(data): @@ -15,7 +20,7 @@ def transpose(data): return data.reshape(data.shape[-1], data.shape[-2]) def load(module, prefix): - args = (sd, prefix, {}, True, [], [], error_msgs) + args = (sd[0], prefix, {}, True, [], [], error_msgs) if len(list(module.parameters())) > 0 and list( module.parameters())[0].numel() == 0: @@ -25,9 +30,9 @@ def load(module, prefix): else: if hasattr(module, 'weight'): module.weight = mp_replace.copy(module.weight.data, - sd[prefix + 'weight']) - if prefix + 'bias' in sd.keys(): - module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias']) + sd[0][prefix + 'weight']) + if prefix + 'bias' in sd[0].keys(): + module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias']) def load_transformer_layer(module, prefix): if ckpt_type == "tp": @@ -35,71 +40,90 @@ def load_transformer_layer(module, prefix): def load_parameters(module, prefix): for n, p in module.named_parameters(): if len(n.split('.')) == 1: - src_shape = sd[prefix + n].shape + src_shape = sd[0][prefix + n].shape dst_shape = p.shape - if (len(src_shape) == 2 and len(dst_shape) == 2): if src_shape[0] == dst_shape[0] and src_shape[ 1] == dst_shape[1]: - p.data.copy_(sd[prefix + n]) + p = weight_quantizer.quantize( + transpose(sd[0][prefix + n]) if weight_quantizer. + q_int8 else sd[0][prefix + n]) + setattr(module, n, p) else: - if src_shape[0] != dst_shape[0]: - weight_split = torch.split( - sd[prefix + n], - dst_shape[0], - dim=0)[rank].to( - torch.cuda.current_device()).contiguous() + dim = 0 if src_shape[0] != dst_shape[0] else 1 + if src_shape[dim] > dst_shape[dim]: + weight_partition = torch.split(sd[0][prefix + n], + dst_shape[0], + dim=dim)[rank] + + p.data.copy_(weight_partition.contiguous()) else: - weight_split = torch.split( - sd[prefix + n], - dst_shape[1], - dim=1)[rank].to( - torch.cuda.current_device()).contiguous() - p.data.copy_(weight_split.contiguous()) + weight_partition = torch.cat([ + sd[j][prefix + n].to(torch.cuda.current_device()) + for j in range(len(sd)) + ], + dim=dim) + + weight_partition = transpose( + weight_partition + ) if weight_quantizer.q_int8 else weight_partition + setattr( + module, + n, + weight_quantizer.quantize( + weight_partition.to( + torch.cuda.current_device()))) else: if src_shape[0] == dst_shape[0]: - p.data.copy_(sd[prefix + n]) + p.data.copy_(sd[0][prefix + n]) else: - bias_split = torch.split( - sd[prefix + n], - dst_shape[-1])[rank].to( - torch.cuda.current_device()).contiguous() - p.data.copy_(bias_split) + if src_shape[0] > dst_shape[0]: + bias_split = torch.split( + sd[0][prefix + n], + dst_shape[-1])[rank].to( + torch.cuda.current_device()).contiguous() + p.data.copy_(bias_split) + else: + p.data.copy_( + torch.cat( + [sd[j][prefix + n] for j in range(len(sd))], + dim=0).to(torch.cuda.current_device()). + contiguous()) load_parameters(module, prefix) for n, child in module.named_children(): load_parameters(child, prefix + n + '.') else: - module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight']) - module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias']) - module.attention.attn_qkvw = mp_replace.copy( - module.attention.attn_qkvw.data, - transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight'])) + module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' 
+ 'weight']) + module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias']) + module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw, + weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']))) module.attention.attn_qkvb = mp_replace.copy( module.attention.attn_qkvb.data, - sd[prefix + 'self_attention.query_key_value.' + 'bias']) - module.attention.attn_ow = mp_replace.copy( - module.attention.attn_ow.data, - transpose(sd[prefix + 'self_attention.dense.' + 'weight'])) + sd[0][prefix + 'self_attention.query_key_value.' + 'bias']) + module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow, + weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight']))) module.attention.attn_ob = mp_replace.copy( module.attention.attn_ob.data, - sd[prefix + 'self_attention.dense.' + 'bias']) - module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' + - 'weight']) - module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + - 'bias']) - module.mlp.inter_w = mp_replace.copy( - module.mlp.inter_w.data, - transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight'])) + sd[0][prefix + 'self_attention.dense.' + 'bias']) + module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + + 'weight']) + module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + + 'bias']) + module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w, + weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']))) module.mlp.inter_b = mp_replace.copy( module.mlp.inter_b.data, - sd[prefix + 'mlp.dense_h_to_4h.' + 'bias']) - module.mlp.output_w = mp_replace.copy( - module.mlp.output_w.data, - transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight'])) + sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias']) + module.mlp.output_w = mp_replace.copy(module.mlp.output_w, + weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']))) module.mlp.output_b = mp_replace.copy( module.mlp.output_b.data, - sd[prefix + 'mlp.dense_4h_to_h.' + 'bias']) + sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias']) layer_policies = { nn.Linear: load, @@ -117,7 +141,7 @@ def load_module_recursive(module, prefix='', level=0): for name, child in module.named_children(): if child.__class__ in layer_policies: checking_key = prefix + name + '.' 
- if not any(checking_key in item for item in sd.keys()): + if not any(checking_key in item for item in sd[0].keys()): if hasattr(child, 'weight') and \ (hasattr(child.weight, 'ds_id') and \ child.weight.ds_id in all_ds_ids): @@ -168,6 +192,6 @@ def load_module_recursive(module, prefix='', level=0): embedding_weight = p assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist" r_module.lm_head.weight = embedding_weight - - del sd + for sd_ in sd: + del sd_ sd = None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 011b66858710..01be32a56566 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -5,7 +5,7 @@ import deepspeed.ops.transformer as transformer_inference from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy from .replace_policy import replace_policies -from ..runtime.weight_quantizer import WeightQuantization +#from ..runtime.weight_quantizer import WeightQuantization from deepspeed import comm as dist from torch import nn @@ -115,8 +115,10 @@ def copy(self, dst, src): dst_shape[-1])[self.gpu_index].to( torch.cuda.current_device()).contiguous() dst.data.copy_(bias_split) - - return torch.nn.parameter.Parameter(dst, requires_grad=False) + dst = torch.nn.parameter.Parameter(dst, requires_grad=False) + if hasattr(src, 'scale'): + dst.scale = src.scale + return dst def get_transformer_name(replaced_module): @@ -134,6 +136,31 @@ def get_transformer_name(replaced_module): return transformer_name +class GroupQuantizer: + def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8): + self.num_groups = num_groups + self.group_size = group_size + self.num_bits = num_bits + self.q_int8 = q_int8 + + def quantize(self, inputs, qkv=True, count=1): + if not self.q_int8 or not qkv: + inputs = torch.nn.Parameter(inputs, requires_grad=False) + inputs.scale = torch.empty(1) + return inputs + q_range = 2**self.num_bits + inputs = inputs.to(torch.cuda.current_device()) + input_flat = inputs.reshape(self.num_groups, -1).contiguous() + input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float() + input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float() + scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range) + input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1) + inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous() + out = torch.nn.Parameter(inputs_q, requires_grad=False) + out.scale = scale + return out + + def replace_transformer_layer(orig_layer_impl, model, policy=None, @@ -231,7 +258,7 @@ def replace_with_policy(child, _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type) attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm() - if quantize: + if False: if policy_cls is not HFBertLayerPolicy: qkvw = qkvw.to(torch.int8) dense_w = dense_w.to(torch.int8) @@ -334,21 +361,21 @@ def replace_with_policy(child, new_module = transformer_inference.DeepSpeedTransformerInference( transformer_config, mp_group=mp_group, - quantize_scales=quantization_scales[layer_id], + #quantize_scales=quantization_scales[layer_id], quantize_groups=quantize_groups, merge_count=merge_count, mlp_extra_grouping=mlp_extra_grouping, qkv_merging=(policy_cls is HFBertLayerPolicy)) - if quantize and qkvw.dtype != torch.int8: - quantize_bits = 8 - quantizer = WeightQuantization() - if policy_cls is HFBertLayerPolicy: - data_quantized, _ = 
quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3) - else: - data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups) - qkvw.data.copy_(data_quantized) - qkvw.data = qkvw.data.to(torch.int8) + #if quantize and qkvw.dtype != torch.int8: + # quantize_bits = 8 + # quantizer = WeightQuantization() + # if policy_cls is HFBertLayerPolicy: + # data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3) + # else: + # data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups) + # qkvw.data.copy_(data_quantized) + # qkvw.data = qkvw.data.to(torch.int8) else: if moe: @@ -783,6 +810,7 @@ def replace_fn(child, _policy, layer_id=0): replace_fn=replace_fn, _replace_policy=policy) + quantizer = GroupQuantizer(q_int8=quantize) if checkpoint_dict is not None: start_time = time.time() rank = dist.get_rank() if dist.is_initialized() else 0 @@ -792,36 +820,70 @@ def replace_fn(child, _policy, layer_id=0): ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size) base_dir = checkpoint_dict.get('base_dir', '') - if ckpt_type == 'pp': + if ckpt_type == 'pp' and type(checkpoint) is list: pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards") + for i in range(len(checkpoint)): - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar.update(1) + sd = torch.load(checkpoint[i], map_location='cpu') - load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) + load_model_with_checkpoint( + replaced_module, + sd, + mp_replace, + ckpt_type, + quantizer, + ) else: - num_checkpoints = len(checkpoint) // ckpt_mp_size - assert world_size >= ckpt_mp_size,\ - "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!" 
- checkpoint_stride = world_size // ckpt_mp_size - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar = tqdm.tqdm(total=num_checkpoints, - desc=f"Loading {num_checkpoints} checkpoint shards") - for i in range(num_checkpoints): + if "tp" in checkpoint: + num_checkpoints = len(checkpoint["tp"]) // ckpt_mp_size + sd_offset = int(rank / (world_size / ckpt_mp_size)) + sd_count = int((rank + 1) / (world_size / ckpt_mp_size)) - sd_offset if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar.update(1) - - ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride) - ckpt_file = os.path.join( - base_dir, - checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index] - sd = torch.load(ckpt_file, map_location='cpu') - load_model_with_checkpoint(replaced_module, - sd, - mp_replace, - ckpt_type, - rank % (world_size // ckpt_mp_size)) + pbar = tqdm.tqdm(total=num_checkpoints, + desc=f"Loading {num_checkpoints} checkpoint shards") + for i in range(num_checkpoints): + if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank( + ) == 0: + pbar.update(1) + ckpt_index = i * ckpt_mp_size + sd_offset + ckpt_files = [ + os.path.join(base_dir, + checkpoint["tp"][ckpt_index + j]) + if base_dir else checkpoint["tp"][ckpt_index + j] + for j in range(sd_count) + ] + + sds = [ + torch.load(ckpt_file, + map_location='cpu') for ckpt_file in ckpt_files + ] + load_model_with_checkpoint(replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % (world_size / ckpt_mp_size))) + + if "non_tp" in checkpoint: + pbar = tqdm.tqdm( + total=len(checkpoint["non_tp"]), + desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards") + + for i in range(len(checkpoint["non_tp"])): + if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank( + ) == 0: + pbar.update(1) + ckpt_file = os.path.join( + base_dir, + checkpoint["non_tp"][i]) if base_dir else checkpoint["non_tp"][i] + sds = [torch.load(ckpt_file, map_location='cpu')] + load_model_with_checkpoint(replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % (world_size / ckpt_mp_size))) print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: @@ -833,7 +895,7 @@ def replace_fn(child, _policy, layer_id=0): dist.barrier() transformer_name = get_transformer_name(replaced_module) non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt' - ckpt_files = [non_tp_ckpt_name] * world_size + ckpt_files = [non_tp_ckpt_name] #* world_size if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( @@ -844,11 +906,14 @@ def replace_fn(child, _policy, layer_id=0): if transformer_name not in k }), f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}') - ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] + #ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] config = json.dumps({ 'type': ckpt_name, 'base_dir': f'{save_mp_checkpoint_path}', - 'checkpoints': ckpt_files, + 'checkpoints': { + "non_tp": ckpt_files, + "tp": [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] + }, 'version': 1.0, 'parallelization': 'tp', 'mp_size': world_size diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index d38cf8c3d395..d90be6682b4c 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -412,18 
+412,25 @@ def selfAttention_fp(): else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 + qkv_out = qkv_func( input, attn_qkvw, + attn_qkvw.scale, (attn_qkvb if attn_qkvb is not None else norm_b), norm_w, norm_b, config.epsilon, (attn_qkvb is not None), 1 if config.bigscience_bloom else - DeepSpeedTransformerInference.layer_id) + DeepSpeedTransformerInference.layer_id, + config.q_int8) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) - output = vector_matmul_func(context_layer, attn_ow, False) + output = vector_matmul_func(context_layer, + attn_ow, + False, + attn_ow.scale, + config.q_int8) return output, key_layer, value_layer, context_layer, qkv_out[-1] @@ -455,7 +462,7 @@ def selfAttention_int8(): (merge_count)) return output, key_layer, value_layer, context_layer - if config.q_int8: + if False: #config.q_int8: output, key_layer, value_layer, context_layer = selfAttention_int8() else: output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp() @@ -483,30 +490,34 @@ def __init__(self, qkv_merging=False): super(DeepSpeedSelfAttention, self).__init__() self.config = config - data_type = torch.half if config.fp16 else torch.float + data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float + data_type_fp = torch.half if config.fp16 else torch.float self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.attn_qkvw = nn.Parameter( - torch.empty(self.config.hidden_size, - (self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type, - device=device)) - self.attn_qkvb = nn.Parameter( - torch.empty((self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type, - device=device)) - - self.attn_ow = nn.Parameter( - torch.empty(self.config.hidden_size // self.config.mp_size, - self.config.hidden_size, - dtype=data_type, - device=device)) - - self.attn_ob = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) + self.attn_qkvw = nn.Parameter(torch.empty( + self.config.hidden_size, + (self.config.hidden_size // self.config.mp_size) * 3, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty( + (self.config.hidden_size // self.config.mp_size) * 3, + dtype=data_type_fp, + device=device), + requires_grad=False) + + self.attn_ow = nn.Parameter(torch.empty(self.config.hidden_size // + self.config.mp_size, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size @@ -591,36 +602,16 @@ def forward(ctx, vector_matmul_func, bias_residual_func): - if config.q_int8: - (intermediate, - residual_add) = inference_cuda_module.mlp_gemm_int8( - input, - residual, - bias, - inter_w, - inter_b, - attn_nw, - attn_nb, - config.epsilon, - q_scales[2], - (q_groups * (2**merge_count)), - config.pre_layer_norm) - output = inference_cuda_module.vector_matmul_int8(intermediate, - output_w, - q_scales[3], - q_groups, - (merge_count)) + if attn_nw is None: + output = 
fused_gemm_gelu(residual_norm, + inter_w, + inter_b, + output_w, + config.epsilon, + config.pre_layer_norm, + False) else: - if attn_nw is None: - output = fused_gemm_gelu(residual_norm, - inter_w, - inter_b, - output_w, - config.epsilon, - config.pre_layer_norm, - False) - else: - intermediate, residual_add = mlp_gemm_func(input, + intermediate, residual_add = mlp_gemm_func(input, residual, bias, inter_w, @@ -629,9 +620,15 @@ def forward(ctx, attn_nb, config.epsilon, config.pre_layer_norm, - config.mlp_after_attn) - output = vector_matmul_func(intermediate, output_w, False) - + config.mlp_after_attn, + inter_w.scale, + config.q_int8) + output = vector_matmul_func(intermediate, + output_w, + False, + output_w.scale, + config.q_int8) + #print(output) inference_cuda_module.residual_add( output, residual if config.pre_layer_norm else residual_add, @@ -663,34 +660,38 @@ def __init__(self, super(DeepSpeedMLP, self).__init__() self.config = config - data_type = torch.half if config.fp16 else torch.float + data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float + data_type_fp = torch.half if config.fp16 else torch.float device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.attn_nw = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) - self.attn_nb = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) - self.inter_w = nn.Parameter( - torch.empty(self.config.hidden_size, - self.config.intermediate_size // self.config.mp_size, - dtype=data_type, - device=device)) - self.inter_b = nn.Parameter( - torch.empty(self.config.intermediate_size // self.config.mp_size, - dtype=data_type, - device=device)) - self.output_w = nn.Parameter( - torch.empty((self.config.intermediate_size // self.config.mp_size), - self.config.hidden_size, - dtype=data_type, - device=device)) - self.output_b = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) + self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, + self.config.intermediate_size // + self.config.mp_size, + dtype=data_type, + device=device), + requires_grad=False) + self.inter_b = nn.Parameter(torch.empty(self.config.intermediate_size // + self.config.mp_size, + dtype=data_type_fp, + device=device), + requires_grad=False) + self.output_w = nn.Parameter(torch.empty( + (self.config.intermediate_size // self.config.mp_size), + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type_fp, + device=device), + requires_grad=False) # used for quantization self.q_scales = q_scales @@ -785,14 +786,14 @@ def __init__(self, mlp_extra_grouping) device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.norm_w = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) - self.norm_b = nn.Parameter( - torch.empty(self.config.hidden_size, - dtype=data_type, - device=device)) + self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, + 
dtype=data_type, + device=device), + requires_grad=False) self.layer_past = None def forward(self, From d98f1f9be1715b12963a905df87546fb5333dbff Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Mon, 15 Aug 2022 07:37:59 +0500 Subject: [PATCH 03/16] fixing some issue with loading checkpoint and bias-add --- .../transformer/inference/csrc/pt_binding.cpp | 2 +- deepspeed/module_inject/load_checkpoint.py | 2 - deepspeed/module_inject/replace_module.py | 58 +++++++++---------- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 3f1d5b935abd..f30dea8e3a28 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -621,7 +621,7 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, if (add_bias) launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); return torch::from_blob(workspace, input.sizes(), input.options()); diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index c7001c32059c..a73a0573a8bd 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -55,8 +55,6 @@ def load_parameters(module, prefix): weight_partition = torch.split(sd[0][prefix + n], dst_shape[0], dim=dim)[rank] - - p.data.copy_(weight_partition.contiguous()) else: weight_partition = torch.cat([ sd[j][prefix + n].to(torch.cuda.current_device()) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 01be32a56566..40fda36c60a6 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -816,8 +816,9 @@ def replace_fn(child, _policy, layer_id=0): rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 checkpoint = checkpoint_dict['checkpoints'] + ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint ckpt_type = checkpoint_dict.get('parallelization', 'pp') - ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size) + ckpt_mp_size = checkpoint_dict.get('mp_size', len(ckpt_list)) base_dir = checkpoint_dict.get('base_dir', '') if ckpt_type == 'pp' and type(checkpoint) is list: @@ -826,7 +827,7 @@ def replace_fn(child, _policy, layer_id=0): for i in range(len(checkpoint)): - sd = torch.load(checkpoint[i], map_location='cpu') + sd = [torch.load(checkpoint[i], map_location='cpu')] load_model_with_checkpoint( replaced_module, sd, @@ -835,35 +836,32 @@ def replace_fn(child, _policy, layer_id=0): quantizer, ) else: - if "tp" in checkpoint: - num_checkpoints = len(checkpoint["tp"]) // ckpt_mp_size - sd_offset = int(rank / (world_size / ckpt_mp_size)) - sd_count = int((rank + 1) / (world_size / ckpt_mp_size)) - sd_offset + num_checkpoints = len(ckpt_list) // ckpt_mp_size + sd_offset = int(rank / (world_size / ckpt_mp_size)) + sd_count = int((rank + 1) / (world_size / ckpt_mp_size)) - sd_offset + if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: + pbar = tqdm.tqdm(total=num_checkpoints, + desc=f"Loading {num_checkpoints} checkpoint shards") + for i in range(num_checkpoints): if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar = tqdm.tqdm(total=num_checkpoints, - desc=f"Loading {num_checkpoints} checkpoint shards") - for i in 
range(num_checkpoints): - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank( - ) == 0: - pbar.update(1) - ckpt_index = i * ckpt_mp_size + sd_offset - ckpt_files = [ - os.path.join(base_dir, - checkpoint["tp"][ckpt_index + j]) - if base_dir else checkpoint["tp"][ckpt_index + j] - for j in range(sd_count) - ] - - sds = [ - torch.load(ckpt_file, - map_location='cpu') for ckpt_file in ckpt_files - ] - load_model_with_checkpoint(replaced_module, - sds, - mp_replace, - ckpt_type, - quantizer, - int(rank % (world_size / ckpt_mp_size))) + pbar.update(1) + ckpt_index = i * ckpt_mp_size + sd_offset + ckpt_files = [ + os.path.join(base_dir, + ckpt_list[ckpt_index + + j]) if base_dir else ckpt_list[ckpt_index + j] + for j in range(sd_count) + ] + sds = [ + torch.load(ckpt_file, + map_location='cpu') for ckpt_file in ckpt_files + ] + load_model_with_checkpoint(replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % (world_size / ckpt_mp_size))) if "non_tp" in checkpoint: pbar = tqdm.tqdm( From ebc82bb0303906bbe4eff1d4289cdc826b191c3b Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 16 Aug 2022 02:23:35 +0500 Subject: [PATCH 04/16] adding the logic to store/restore scale for INT8 checkpoint --- deepspeed/module_inject/load_checkpoint.py | 54 +++++++++++-------- deepspeed/module_inject/replace_module.py | 11 +++- .../inference/transformer_inference.py | 1 - 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index a73a0573a8bd..4601fbb92223 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -40,21 +40,35 @@ def load_transformer_layer(module, prefix): def load_parameters(module, prefix): for n, p in module.named_parameters(): if len(n.split('.')) == 1: - src_shape = sd[0][prefix + n].shape + if type(sd[0][prefix + n]) is list: + tmp_data, scale = sd[0][prefix + n] + scale = scale.to(torch.cuda.current_device()) + else: + tmp_data = sd[0][prefix + n] + scale = None + src_shape = tmp_data.shape dst_shape = p.shape + inner_dim = 1 if tmp_data.dtype == torch.int8 else 0 + outer_dim = 0 if tmp_data.dtype == torch.int8 else 1 if (len(src_shape) == 2 and len(dst_shape) == 2): - if src_shape[0] == dst_shape[0] and src_shape[ - 1] == dst_shape[1]: - p = weight_quantizer.quantize( - transpose(sd[0][prefix + n]) if weight_quantizer. - q_int8 else sd[0][prefix + n]) + if (src_shape[inner_dim] == dst_shape[0] + and src_shape[outer_dim] == dst_shape[1]): + if tmp_data.dtype != torch.int8: + p = weight_quantizer.quantize( + transpose(tmp_data) if weight_quantizer. 
+ q_int8 else tmp_data) + else: + p = tmp_data + p.scale = scale setattr(module, n, p) else: - dim = 0 if src_shape[0] != dst_shape[0] else 1 + dim = inner_dim if src_shape[inner_dim] != dst_shape[ + 0] else outer_dim if src_shape[dim] > dst_shape[dim]: - weight_partition = torch.split(sd[0][prefix + n], - dst_shape[0], - dim=dim)[rank] + weight_partition = torch.split( + tmp_data, + dst_shape[0], + dim=dim)[rank].to(torch.cuda.current_device()) else: weight_partition = torch.cat([ sd[j][prefix + n].to(torch.cuda.current_device()) @@ -62,22 +76,20 @@ def load_parameters(module, prefix): ], dim=dim) - weight_partition = transpose( - weight_partition - ) if weight_quantizer.q_int8 else weight_partition - setattr( - module, - n, - weight_quantizer.quantize( - weight_partition.to( - torch.cuda.current_device()))) + if tmp_data.dtype != torch.int8: + weight_partition = weight_quantizer.quantize( + transpose(weight_partition) if weight_quantizer. + q_int8 else weight_partition) + else: + weight_partition.scale = scale + setattr(module, n, weight_partition) else: if src_shape[0] == dst_shape[0]: - p.data.copy_(sd[0][prefix + n]) + p.data.copy_(tmp_data) else: if src_shape[0] > dst_shape[0]: bias_split = torch.split( - sd[0][prefix + n], + tmp_data, dst_shape[-1])[rank].to( torch.cuda.current_device()).contiguous() p.data.copy_(bias_split) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 40fda36c60a6..6b5ab302acc4 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -919,11 +919,18 @@ def replace_fn(child, _policy, layer_id=0): with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", "w") as cfg: cfg.write(config) + + rep_sd = replaced_module.state_dict() + for n, p in replaced_module.named_parameters(): + if hasattr(p, 'scale'): + rep_sd[n] = [p, p.scale] torch.save( OrderedDict({ - k: v + k: [v, + v.scale] if hasattr(v, + 'scale') else v for k, - v in dict(replaced_module.state_dict()).items() if transformer_name in k + v in dict(rep_sd).items() if transformer_name in k }), f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt') diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index d90be6682b4c..31505ff6bc09 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -412,7 +412,6 @@ def selfAttention_fp(): else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 - qkv_out = qkv_func( input, attn_qkvw, From 43a7023078b15cbecb71742263869ddff1476bcc Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 16 Aug 2022 02:33:19 +0500 Subject: [PATCH 05/16] add empty quantization scale for different models to run with fp16 --- deepspeed/module_inject/replace_module.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 6b5ab302acc4..889876783427 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -510,10 +510,14 @@ def _transpose(x): attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) else: - attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw) - attn_block.attn_qkvb = 
mp_replace.copy(attn_block.attn_qkvb, qkvb) - - attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) + attn_block.attn_qkvw = quantizer.quantize( + mp_replace.qkv_copy(attn_block.attn_qkvw, + qkvw)) + attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb) + + attn_block.attn_ow = quantizer.quantize( + mp_replace.copy(attn_block.attn_ow, + dense_w)) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) if moe: @@ -569,9 +573,13 @@ def _transpose(x): mpl_block.output_b, _4hh_b) else: - mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w) + mpl_block.inter_w = quantizer.quantize( + mp_replace.copy(mpl_block.inter_w, + _h4h_w)) mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b) - mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w) + mpl_block.output_w = quantizer.quantize( + mp_replace.copy(mpl_block.output_w, + _4hh_w)) mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b) if attn_nw is None: From 00aa18885b9eecf82860e37d0ea492b8b095f2bb Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 16 Aug 2022 02:45:41 +0500 Subject: [PATCH 06/16] Empty-Commit From 84e0d03bc5ef9339d5bac6d8e784c50aaf098377 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 18 Aug 2022 22:58:52 +0500 Subject: [PATCH 07/16] fix sevral issues after merging with master --- csrc/transformer/inference/csrc/pt_binding.cpp | 10 +++------- deepspeed/module_inject/replace_module.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 5c27036bb982..5265bfb3e196 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -939,11 +939,6 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, if (q_int8) { quantized_gemm(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz); - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(0), - bsz, - Context::Instance().GetCurrentStream()); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -965,16 +960,17 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } if (act_func_type == ActivationFuncType::GELU) { launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } else if (act_func_type == ActivationFuncType::ReLU) { launch_bias_relu((T*)output.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + q_int8 ? 
weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 2971589d4884..2044bd56ff94 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -506,9 +506,11 @@ def _transpose(x): attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) else: attn_block.attn_qkvw = quantizer.quantize( - mp_replace.qkv_copy(attn_block.attn_qkvw, - qkvw)) - attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb) + mp_replace.copy(attn_block.attn_qkvw, qkvw) if bigscience_bloom else \ + mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)) + attn_block.attn_qkvb = \ + mp_replace.copy(attn_block.attn_qkvb, qkvb) if bigscience_bloom else \ + mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb) attn_block.attn_ow = quantizer.quantize( mp_replace.copy(attn_block.attn_ow, @@ -810,10 +812,10 @@ def replace_fn(child, _policy, layer_id=0): _replace_policy=policy) quantizer = GroupQuantizer(q_int8=quantize) + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 if checkpoint_dict is not None: start_time = time.time() - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 checkpoint = checkpoint_dict['checkpoints'] ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint ckpt_type = checkpoint_dict.get('parallelization', 'pp') @@ -887,7 +889,7 @@ def replace_fn(child, _policy, layer_id=0): from collections import OrderedDict import json - ckpt_name = checkpoint_dict['type'] + ckpt_name = checkpoint_dict.get('type', 'ds_model') if dist.is_initialized(): dist.barrier() transformer_name = get_transformer_name(replaced_module) From f6cb028d1fb4a3bceacf1030cd70bef4fb4a68c9 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 20 Aug 2022 03:47:47 +0500 Subject: [PATCH 08/16] several fixes for generating the INT8 sharded checkpoint --- deepspeed/__init__.py | 6 +- deepspeed/inference/engine.py | 64 +++++++------- deepspeed/module_inject/load_checkpoint.py | 45 +++++++--- deepspeed/module_inject/replace_module.py | 98 ++++++++++++++-------- 4 files changed, 138 insertions(+), 75 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 1a5ee744d9bd..0c5158562126 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -242,7 +242,8 @@ def init_inference(model, moe_type='standard', args=None, enable_cuda_graph=False, - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + base_dir=""): """Initialize the DeepSpeed InferenceEngine. 
Arguments: @@ -309,6 +310,7 @@ def init_inference(model, moe_type, args, enable_cuda_graph, - save_mp_checkpoint_path) + save_mp_checkpoint_path, + base_dir) return engine diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 9844e70330f4..427ad5d4f9c6 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -51,7 +51,8 @@ def __init__(self, moe_type='standard', config=None, enable_cuda_graph=False, - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + base_dir=""): """ Args: model: torch.nn.Module @@ -135,7 +136,8 @@ def __init__(self, moe_type, training_mp_size, self.checkpoint if replace_with_kernel_inject else None, - save_mp_checkpoint_path=save_mp_checkpoint_path) + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir) elif replace_method == 'auto': self._apply_injection_policy( return_tuple=return_tuple, @@ -145,7 +147,8 @@ def __init__(self, moe_type=moe_type, training_mp_size=training_mp_size, checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None, - save_mp_checkpoint_path=save_mp_checkpoint_path) + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir) device = torch.cuda.current_device() self.module.to(device) @@ -326,36 +329,37 @@ def _apply_injection_policy(self, moe_type='standard', training_mp_size=1, checkpoint_dir=None, - save_mp_checkpoint_path=False): + save_mp_checkpoint_path=False, + base_dir=""): checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None - replace_transformer_layer( - client_module, - self.module, - triangular_masking=self.triangular_masking, - policy=injection_policy, - mp_size=self.mp_world_size, - mp_group=self.mp_group, - ep_group=self.ep_group, - expert_mp_group=self.expert_mp_group, - config=self.config, - fp16=(self.dtype == torch.half) or (self.dtype == torch.int8), - training=False, - return_tuple=return_tuple, - quantize=(self.dtype == torch.int8), - quantize_settings=(self.quantization_scales, - self.quantize_merge_count, - self.mlp_extra_grouping, - self.quantize_groups), - replace_with_kernel_inject=replace_with_kernel_inject, - moe=moe, - moe_experts=moe_experts, - moe_type=moe_type, - training_mp_size=training_mp_size, - checkpoint_dict=checkpoint, - save_mp_checkpoint_path=save_mp_checkpoint_path, - ) + replace_transformer_layer(client_module, + self.module, + triangular_masking=self.triangular_masking, + policy=injection_policy, + mp_size=self.mp_world_size, + mp_group=self.mp_group, + ep_group=self.ep_group, + expert_mp_group=self.expert_mp_group, + config=self.config, + fp16=(self.dtype == torch.half) + or (self.dtype == torch.int8), + training=False, + return_tuple=return_tuple, + quantize=(self.dtype == torch.int8), + quantize_settings=(self.quantization_scales, + self.quantize_merge_count, + self.mlp_extra_grouping, + self.quantize_groups), + replace_with_kernel_inject=replace_with_kernel_inject, + moe=moe, + moe_experts=moe_experts, + moe_type=moe_type, + training_mp_size=training_mp_size, + checkpoint_dict=checkpoint, + save_mp_checkpoint_path=save_mp_checkpoint_path, + base_dir=base_dir) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 4601fbb92223..c4cb02075a9a 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -39,7 +39,7 
@@ def load_transformer_layer(module, prefix): def load_parameters(module, prefix): for n, p in module.named_parameters(): - if len(n.split('.')) == 1: + if prefix + n in sd[0] and len(n.split('.')) == 1: if type(sd[0][prefix + n]) is list: tmp_data, scale = sd[0][prefix + n] scale = scale.to(torch.cuda.current_device()) @@ -64,23 +64,48 @@ def load_parameters(module, prefix): else: dim = inner_dim if src_shape[inner_dim] != dst_shape[ 0] else outer_dim - if src_shape[dim] > dst_shape[dim]: + dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1 + if src_shape[dim] > dst_shape[dim1]: weight_partition = torch.split( tmp_data, - dst_shape[0], + dst_shape[dim1], dim=dim)[rank].to(torch.cuda.current_device()) + assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \ + '''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint!\ + Please use the FP16 checkpoint to generate INT8 checkpoint with the sharding parameters!''' + scale = scale.view( + -1)[weight_quantizer.num_groups * + (rank + 1):].reshape( + weight_quantizer.num_groups, + -1).contiguous() else: - weight_partition = torch.cat([ - sd[j][prefix + n].to(torch.cuda.current_device()) - for j in range(len(sd)) - ], - dim=dim) + assert tmp_data.dtype != torch.int8, \ + '''Merging of the checkpoints are not supported when using INT8 checkpoint! \ + Please use a as many GPUs as TP-size for the checkpoint''' + all_data = [ + sd[j][prefix + n] for j in range(len(sd)) + ] + weight_partition = torch.cat( + [(ad[0] if type(ad) is list else ad).to( + torch.cuda.current_device()) + for ad in all_data], + dim=dim) + if tmp_data.dtype == torch.int8: + scale = torch.cat([ + ad[1].to(torch.cuda.current_device()) + for ad in all_data + ], + dim=dim) if tmp_data.dtype != torch.int8: weight_partition = weight_quantizer.quantize( - transpose(weight_partition) if weight_quantizer. 
- q_int8 else weight_partition) + transpose(weight_partition), \ + parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \ + weight_quantizer.quantize(weight_partition) else: + weight_partition = torch.nn.parameter.Parameter( + weight_partition, + requires_grad=False) weight_partition.scale = scale setattr(module, n, weight_partition) else: diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 2044bd56ff94..2a6ca53a1241 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -143,7 +143,7 @@ def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8): self.num_bits = num_bits self.q_int8 = q_int8 - def quantize(self, inputs, qkv=True, count=1): + def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): if not self.q_int8 or not qkv: inputs = torch.nn.Parameter(inputs, requires_grad=False) inputs.scale = torch.empty(1) @@ -157,7 +157,33 @@ def quantize(self, inputs, qkv=True, count=1): input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1) inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous() out = torch.nn.Parameter(inputs_q, requires_grad=False) - out.scale = scale + #print(inputs.shape) + inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim) + input_flat = [ + inputs_split[i].reshape(self.num_groups, + -1).contiguous() for i in range(2) + ] + input_min = [ + torch.min(input_flat[i], + dim=1, + keepdim=True)[0].float() for i in range(2) + ] + input_max = [ + torch.max(input_flat[i], + dim=1, + keepdim=True)[0].float() for i in range(2) + ] + scale1 = [ + (torch.max(input_min[i].abs(), + input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0) + for i in range(2) + ] + + out.scale = torch.cat([scale.squeeze().unsqueeze(0), + scale1[0], + scale1[1]], + dim=0).reshape(self.num_groups, + -1).contiguous() return out @@ -188,7 +214,8 @@ def replace_transformer_layer(orig_layer_impl, moe_experts=1, moe_type='standard', checkpoint_dict=None, - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + base_dir=""): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -820,7 +847,7 @@ def replace_fn(child, _policy, layer_id=0): ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint ckpt_type = checkpoint_dict.get('parallelization', 'pp') ckpt_mp_size = checkpoint_dict.get('mp_size', len(ckpt_list)) - base_dir = checkpoint_dict.get('base_dir', '') + base_dir1 = checkpoint_dict.get('base_dir', base_dir) if ckpt_type == 'pp' and type(checkpoint) is list: pbar = tqdm.tqdm(total=len(checkpoint), @@ -838,19 +865,19 @@ def replace_fn(child, _policy, layer_id=0): ) else: num_checkpoints = len(ckpt_list) // ckpt_mp_size - sd_offset = int(rank / (world_size / ckpt_mp_size)) - sd_count = int((rank + 1) / (world_size / ckpt_mp_size)) - sd_offset - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar = tqdm.tqdm(total=num_checkpoints, - desc=f"Loading {num_checkpoints} checkpoint shards") + tp_split_size = (world_size / ckpt_mp_size) + sd_offset = int(rank / tp_split_size) + sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset + pbar = tqdm.tqdm(total=num_checkpoints, + desc=f"Loading {num_checkpoints} checkpoint shards") for i in range(num_checkpoints): - if not deepspeed.comm.is_initialized() or 
deepspeed.comm.get_rank() == 0: - pbar.update(1) + pbar.update(1) ckpt_index = i * ckpt_mp_size + sd_offset ckpt_files = [ - os.path.join(base_dir, + os.path.join(base_dir1, ckpt_list[ckpt_index + - j]) if base_dir else ckpt_list[ckpt_index + j] + j]) if base_dir1 else ckpt_list[ckpt_index + + j] for j in range(sd_count) ] sds = [ @@ -862,7 +889,7 @@ def replace_fn(child, _policy, layer_id=0): mp_replace, ckpt_type, quantizer, - int(rank % (world_size / ckpt_mp_size))) + int(rank % tp_split_size)) if "non_tp" in checkpoint: pbar = tqdm.tqdm( @@ -870,25 +897,23 @@ def replace_fn(child, _policy, layer_id=0): desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards") for i in range(len(checkpoint["non_tp"])): - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank( - ) == 0: - pbar.update(1) - ckpt_file = os.path.join( - base_dir, - checkpoint["non_tp"][i]) if base_dir else checkpoint["non_tp"][i] + pbar.update(1) + ckpt_file = os.path.join(base_dir1, + checkpoint["non_tp"][i] + ) if base_dir1 else checkpoint["non_tp"][i] sds = [torch.load(ckpt_file, map_location='cpu')] load_model_with_checkpoint(replaced_module, sds, mp_replace, ckpt_type, quantizer, - int(rank % (world_size / ckpt_mp_size))) + int(rank % tp_split_size)) print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: from collections import OrderedDict import json - + num_partitions = 8 ckpt_name = checkpoint_dict.get('type', 'ds_model') if dist.is_initialized(): dist.barrier() @@ -910,8 +935,12 @@ def replace_fn(child, _policy, layer_id=0): 'type': ckpt_name, 'base_dir': f'{save_mp_checkpoint_path}', 'checkpoints': { - "non_tp": ckpt_files, - "tp": [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] + "non_tp": + ckpt_files, + "tp": [ + f'{ckpt_name}-tp_{r:0>2d}_{m:0>2d}.pt' + for m in range(num_partitions) for r in range(world_size) + ] }, 'version': 1.0, 'parallelization': 'tp', @@ -925,15 +954,18 @@ def replace_fn(child, _policy, layer_id=0): for n, p in replaced_module.named_parameters(): if hasattr(p, 'scale'): rep_sd[n] = [p, p.scale] - torch.save( - OrderedDict({ - k: [v, - v.scale] if hasattr(v, - 'scale') else v - for k, - v in dict(rep_sd).items() if transformer_name in k - }), - f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt') + keys = list(rep_sd.keys()) + partition_size = (len(keys) // num_partitions + 1) + for m in range(num_partitions): + torch.save( + OrderedDict({ + k: [rep_sd[k], + rep_sd[k].scale] if hasattr(rep_sd[k], + 'scale') else rep_sd[k] + for k in keys[m * partition_size:(m + 1) * partition_size] + if transformer_name in k + }), + f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}_{m:0>2d}.pt') return replaced_module From cb72d9cea07a3714cd115cab8e35acdb5f9397fa Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 20 Aug 2022 06:36:00 +0500 Subject: [PATCH 09/16] move quantizer declaration before inference branch --- deepspeed/module_inject/replace_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index cbe409cd6fbe..b1ba2a9dbd5a 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -311,6 +311,7 @@ def replace_with_policy(child, #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group) + quantizer = GroupQuantizer(q_int8=quantize) if inference: if moe: ep_world_size = dist.get_world_size() From 57779eff8ce72a89b9db69692a4f8c59d1e7dae6 Mon Sep 
17 00:00:00 2001 From: Reza Yazdani Date: Wed, 24 Aug 2022 22:42:50 +0500 Subject: [PATCH 10/16] fixing some part to catch up with latest update on HF side --- deepspeed/inference/engine.py | 5 ++ deepspeed/module_inject/replace_module.py | 6 ++- .../inference/transformer_inference.py | 54 +++++++++---------- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 427ad5d4f9c6..50ffb9bc0e55 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -100,6 +100,7 @@ def __init__(self, self.cuda_graph_created = False self.checkpoint_engine = TorchCheckpointEngine() self._init_quantization_setting(quantization_setting) + self._add_ds_inference_flag() if enable_cuda_graph: assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ @@ -168,6 +169,10 @@ def _get_model_config_generate(self, config): self.config = getattr(self.module, 'config', None) if config is None else config self.generate = getattr(self.module, 'generate', None) + def _add_ds_inference_flag(self): + if hasattr(self.module, 'transformer'): + setattr(self.module.transformer, 'ds_inference', True) + def _create_model_parallel_group(self): # Call the init process if InferenceEngine.inference_mp_group is None: diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index b1ba2a9dbd5a..d9d93a5bde16 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -856,7 +856,11 @@ def replace_fn(child, _policy, layer_id=0): for i in range(len(checkpoint)): - sd = [torch.load(checkpoint[i], map_location='cpu')] + sd = [ + torch.load(os.path.join(base_dir1, + checkpoint[i]), + map_location='cpu') + ] load_model_with_checkpoint( replaced_module, sd, diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 2cc6714ebd4b..8c21f5d0afcd 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -206,16 +206,6 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) - value_layer = torch.cat((past_value.type_as(value_layer), - value_layer), - dim=1) - - presents = (key_layer, value_layer) - # [batch_size, head_dim, q_length, k_length] output_size = (query_layer.size(0), query_layer.size(2), @@ -223,24 +213,39 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): key_layer.size(1)) # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] query_layer = query_layer.transpose(1, - 0).reshape( - output_size[2], + 2).reshape( output_size[0] * output_size[1], + output_size[2], -1) # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] key_layer = key_layer.transpose(1, - 0).reshape(output_size[3], - output_size[0] * output_size[1], - -1) + 2).reshape(output_size[0] * output_size[1], + output_size[3], + -1).transpose(-1, + -2) + value_layer = value_layer.transpose(1, + 2).reshape( + output_size[0] * output_size[1], + output_size[3], + -1) + if layer_past is not None: + 
past_key, past_value = layer_past + #if config.layer_id == 0: + # import pdb;pdb.set_trace() + # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1) + value_layer = torch.cat((past_value.type_as(value_layer), + value_layer), + dim=-2) + presents = (key_layer, value_layer) # Raw attention scores. [batch_size * num_heads, q_length, k_length] - matmul_result = torch.matmul(query_layer.transpose(1, - 0), - key_layer.transpose(1, - 0).transpose(1, - 2)) + matmul_result = torch.matmul(query_layer, key_layer) # change view to [batch_size, num_heads, q_length, k_length] - attention_scores = matmul_result.view(*output_size) + attention_scores = matmul_result.view(output_size[0], + output_size[1], + output_size[2], + -1) offset = dist.get_rank( ) * num_attention_heads_per_partition if dist.is_initialized() else 0 @@ -261,12 +266,7 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): attention_probs_reshaped = attention_probs.view(*matmul_result.shape) # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = torch.bmm( - attention_probs_reshaped, - value_layer.transpose(1, - 2).reshape(-1, - value_layer.size(1), - value_layer.size(3))) + context_layer = torch.bmm(attention_probs_reshaped, value_layer) # change view [batch_size, num_heads, q_length, head_dim] context_layer = context_layer.view( From dbcb6ec521309add60f9da28c55b1b75726cc256 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 25 Aug 2022 12:49:11 +0500 Subject: [PATCH 11/16] reducing the CPU memory usage when loading checkpoint (this solves the issue when there is not enough CPU memory to load large models --- deepspeed/module_inject/load_checkpoint.py | 23 ++++++++++++++-------- deepspeed/module_inject/replace_module.py | 5 +++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index c4cb02075a9a..1b45b377cef3 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -3,6 +3,7 @@ from ..runtime.zero import GatheredParameters from .layers import LinearLayer, Normalize, EmbeddingLayer import torch +import gc def load_model_with_checkpoint(r_module, @@ -33,6 +34,8 @@ def load(module, prefix): sd[0][prefix + 'weight']) if prefix + 'bias' in sd[0].keys(): module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias']) + args = None + gc.collect() def load_transformer_layer(module, prefix): if ckpt_type == "tp": @@ -41,10 +44,10 @@ def load_parameters(module, prefix): for n, p in module.named_parameters(): if prefix + n in sd[0] and len(n.split('.')) == 1: if type(sd[0][prefix + n]) is list: - tmp_data, scale = sd[0][prefix + n] + tmp_data, scale = sd[0][prefix + n].to(torch.cuda.current_device()) scale = scale.to(torch.cuda.current_device()) else: - tmp_data = sd[0][prefix + n] + tmp_data = sd[0][prefix + n].to(torch.cuda.current_device()) scale = None src_shape = tmp_data.shape dst_shape = p.shape @@ -83,13 +86,16 @@ def load_parameters(module, prefix): '''Merging of the checkpoints are not supported when using INT8 checkpoint! 
\ Please use a as many GPUs as TP-size for the checkpoint''' all_data = [ - sd[j][prefix + n] for j in range(len(sd)) + sd[j][prefix + + n] if type(sd[j][prefix + n]) is list else + sd[j][prefix + n].to(torch.cuda.current_device()) + for j in range(len(sd)) ] - weight_partition = torch.cat( - [(ad[0] if type(ad) is list else ad).to( - torch.cuda.current_device()) - for ad in all_data], - dim=dim) + weight_partition = torch.cat([ + ad[0].to(torch.cuda.current_device()) + if type(ad) is list else ad for ad in all_data + ], + dim=dim) if tmp_data.dtype == torch.int8: scale = torch.cat([ ad[1].to(torch.cuda.current_device()) @@ -230,3 +236,4 @@ def load_module_recursive(module, prefix='', level=0): for sd_ in sd: del sd_ sd = None + gc.collect() diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e8f6525215b6..5dde87515d8a 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -870,6 +870,7 @@ def replace_fn(child, _policy, layer_id=0): quantizer, ) else: + import gc num_checkpoints = len(ckpt_list) // ckpt_mp_size tp_split_size = (world_size / ckpt_mp_size) sd_offset = int(rank / tp_split_size) @@ -896,6 +897,8 @@ def replace_fn(child, _policy, layer_id=0): ckpt_type, quantizer, int(rank % tp_split_size)) + sds = [None for _ in sds] + gc.collect() if "non_tp" in checkpoint: pbar = tqdm.tqdm( @@ -914,6 +917,8 @@ def replace_fn(child, _policy, layer_id=0): ckpt_type, quantizer, int(rank % tp_split_size)) + sds = [None for _ in sds] + gc.collect() print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: From cd80eccbfa290b5da14df4e03709e651a79ca0dd Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 26 Aug 2022 03:09:39 +0500 Subject: [PATCH 12/16] some minor modification to the ckpt names --- deepspeed/module_inject/replace_module.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 5dde87515d8a..138ad79bc0e2 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -939,8 +939,8 @@ def replace_fn(child, _policy, layer_id=0): if dist.is_initialized(): dist.barrier() transformer_name = get_transformer_name(replaced_module) - non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt' - ckpt_files = [non_tp_ckpt_name] #* world_size + non_tp_ckpt_name = f'non-tp.pt' + ckpt_files = [non_tp_ckpt_name] os.makedirs(save_mp_checkpoint_path, exist_ok=True) if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") @@ -952,7 +952,6 @@ def replace_fn(child, _policy, layer_id=0): if transformer_name not in k }), f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}') - #ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] config = json.dumps({ 'type': ckpt_name, 'base_dir': f'{save_mp_checkpoint_path}', @@ -960,13 +959,13 @@ def replace_fn(child, _policy, layer_id=0): "non_tp": ckpt_files, "tp": [ - f'{ckpt_name}-tp_{r:0>2d}_{m:0>2d}.pt' - for m in range(num_partitions) for r in range(world_size) + f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) + for r in range(world_size) ] }, 'version': 1.0, 'parallelization': 'tp', - 'mp_size': world_size + 'tp_size': world_size }) with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", "w") as cfg: From 82a37d6d4393d9ed76a80dcb392c51e48188fd14 Mon Sep 17 00:00:00 2001 From: 
Reza Yazdani Date: Fri, 26 Aug 2022 05:08:25 +0500 Subject: [PATCH 13/16] remove masking and some configuration changes --- deepspeed/inference/engine.py | 9 +++++--- deepspeed/module_inject/load_checkpoint.py | 3 ++- deepspeed/module_inject/replace_module.py | 22 ++++++++++++------- .../inference/transformer_inference.py | 2 -- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 50ffb9bc0e55..a4b57a05f37b 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -100,7 +100,9 @@ def __init__(self, self.cuda_graph_created = False self.checkpoint_engine = TorchCheckpointEngine() self._init_quantization_setting(quantization_setting) - self._add_ds_inference_flag() + + # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture + self.remove_mask_prepare_for_bloom() if enable_cuda_graph: assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ @@ -169,9 +171,10 @@ def _get_model_config_generate(self, config): self.config = getattr(self.module, 'config', None) if config is None else config self.generate = getattr(self.module, 'generate', None) - def _add_ds_inference_flag(self): + def remove_mask_prepare_for_bloom(self): if hasattr(self.module, 'transformer'): - setattr(self.module.transformer, 'ds_inference', True) + if hasattr(self.module.transformer, '_prepare_attn_mask'): + self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask def _create_model_parallel_group(self): # Call the init process diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 1b45b377cef3..b7f14ca0ebcd 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -44,7 +44,8 @@ def load_parameters(module, prefix): for n, p in module.named_parameters(): if prefix + n in sd[0] and len(n.split('.')) == 1: if type(sd[0][prefix + n]) is list: - tmp_data, scale = sd[0][prefix + n].to(torch.cuda.current_device()) + tmp_data, scale = sd[0][prefix + n] + tmp_data = tmp_data.to(torch.cuda.current_device()) scale = scale.to(torch.cuda.current_device()) else: tmp_data = sd[0][prefix + n].to(torch.cuda.current_device()) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 138ad79bc0e2..1179b425d5f5 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -953,8 +953,10 @@ def replace_fn(child, _policy, layer_id=0): }), f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}') config = json.dumps({ - 'type': ckpt_name, - 'base_dir': f'{save_mp_checkpoint_path}', + 'type': + ckpt_name, + 'base_dir': + f'{save_mp_checkpoint_path}', 'checkpoints': { "non_tp": ckpt_files, @@ -963,12 +965,16 @@ def replace_fn(child, _policy, layer_id=0): for r in range(world_size) ] }, - 'version': 1.0, - 'parallelization': 'tp', - 'tp_size': world_size + 'version': + 1.0, + 'parallelization': + 'tp', + 'mp_size': + world_size, + 'dtype': + 'int8' if quantize else ('float16' if fp16 else 'float32') }) - with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", - "w") as cfg: + with open(f"{save_mp_checkpoint_path}/ds-inference_config.json", "w") as cfg: cfg.write(config) rep_sd = replaced_module.state_dict() @@ -986,7 +992,7 @@ def replace_fn(child, _policy, layer_id=0): for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k }), - 
f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}_{m:0>2d}.pt') + f'{save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt') return replaced_module diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 8c21f5d0afcd..fa28a34f04a2 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -230,8 +230,6 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): -1) if layer_past is not None: past_key, past_value = layer_past - #if config.layer_id == 0: - # import pdb;pdb.set_trace() # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1) value_layer = torch.cat((past_value.type_as(value_layer), From 9d1265615a35ea6031fe6d8c4111fb020998086f Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 26 Aug 2022 05:20:21 +0500 Subject: [PATCH 14/16] remove dead code --- csrc/transformer/inference/csrc/gelu.cu | 175 ------------------ .../inference/includes/custom_cuda_layers.h | 19 -- 2 files changed, 194 deletions(-) diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index f0ba2e2c07d9..34f78101f33d 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -538,178 +538,3 @@ template void launch_moe_res_matmul(__half* residual, int seq_len, int hidden_dim, cudaStream_t stream); - -__device__ void quantize_kernel_glue(float2* data, - unsigned cnt, - int8_t* vals_int, - float* q_scale_d, - int num_bits, - int group_size) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int gid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - int id = threadIdx.x; - - float* vals_int_cast = reinterpret_cast(vals_int); - - __half max = -10000.0; - int bid = blockIdx.x; - unsigned group_index; - for (int i = 0; i < cnt; i++) { - __half* data_h = reinterpret_cast<__half*>(&data[i]); - if (__hgt(__habs(data_h[0]), max)) max = __habs(data_h[0]); - if (__hgt(__habs(data_h[1]), max)) max = __habs(data_h[1]); - if (__hgt(__habs(data_h[2]), max)) max = __habs(data_h[2]); - if (__hgt(__habs(data_h[3]), max)) max = __habs(data_h[3]); - } - -#pragma unroll - for (int i = 1; i < WARP_SIZE; i <<= 1) { - auto temp = g.shfl_xor(max, i); - if (__hgt(temp, max)) max = temp; - } - __shared__ __half partialMax[WARP_SIZE]; - - if (lane == 0) partialMax[gid] = max; - - b.sync(); - - max = partialMax[lane]; - - b.sync(); - -#pragma unroll - for (int i = 1; i < warp_num; i <<= 1) { - auto temp = g.shfl_xor(max, i); - if (__hgt(temp, max)) max = temp; - } - max = g.shfl(max, 0); - - float q_scale = (1 << num_bits) / (2 * (float)max); - - group_index = threadIdx.x + bid * group_size; - for (int i = 0; i < cnt; i++) { - float q_data_int; // = (float)(int)(1 << 8 | 1 << 16 | 1 << 24 | 1); - int8_t* q_data_8 = reinterpret_cast(&q_data_int); - __half* data_h = reinterpret_cast<__half*>(&data[i]); - int32_t data_f[4]; - data_f[0] = round((float)data_h[0] * q_scale); - data_f[1] = round((float)data_h[1] * q_scale); - data_f[2] = round((float)data_h[2] * q_scale); - data_f[3] = round((float)data_h[3] * q_scale); - q_data_8[0] = data_f[0] > 127 ? 127 : (data_f[0] < -128 ? -128 : data_f[0]); - q_data_8[1] = data_f[1] > 127 ? 127 : (data_f[1] < -128 ? 
-128 : data_f[1]); - q_data_8[2] = data_f[2] > 127 ? 127 : (data_f[2] < -128 ? -128 : data_f[2]); - q_data_8[3] = data_f[3] > 127 ? 127 : (data_f[3] < -128 ? -128 : data_f[3]); - vals_int_cast[group_index] = q_data_int; - group_index += (blockDim.x); - } - if (threadIdx.x == 0) q_scale_d[blockIdx.x] = 1 / q_scale; -} -__global__ void fused_bias_gelu_int8(int8_t* output, - float* scales, - __half* input, - const __half* bias, - int total_count, - int intermediate_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * intermediate_size; - int id = threadIdx.x; - float2 vals_vec[8]; - unsigned cnt = 0; - while (id < intermediate_size) { - vals_vec[cnt] = input_cast[offset + id]; - float2 bias_vec = bias_cast[id]; - - __half2* vals_half = reinterpret_cast<__half2*>(vals_vec + cnt); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - // input_cast[offset + id] = vals_vec; - id += blockDim.x; - cnt++; - } - quantize_kernel_glue(vals_vec, cnt, output, scales, 8, intermediate_size); -#endif -} -__global__ void quantize_int8(int8_t* output, - float* scales, - __half* input, - int total_count, - int intermediate_size) -{ - float2* input_cast = reinterpret_cast(input); - - int offset = blockIdx.x * intermediate_size; - int id = threadIdx.x; - float2 vals_vec[8]; - unsigned cnt = 0; - while (id < intermediate_size) { - vals_vec[cnt] = input_cast[offset + id]; - - id += blockDim.x; - cnt++; - } - quantize_kernel_glue(vals_vec, cnt, output, scales, 8, intermediate_size); -} - -void launch_bias_gelu_int8(int8_t* output, - float* scales, - __half* input, - const __half* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int total_count = batch_size * (intermediate_size / 4); - int threads = 1024; // intermediate_size / iterations / 4; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); // (batch_size); - - fused_bias_gelu_int8<<>>( - output, scales, input, bias, total_count, intermediate_size / 4); -} - -void launch_me(int8_t* output, - float* scales, - __half* input, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int total_count = batch_size * (intermediate_size / 4); - int threads = 1024; // intermediate_size / iterations / 4; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); // (batch_size); - - quantize_int8<<>>( - output, scales, input, total_count, intermediate_size / 4); -} diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index 6f3da325234a..bfa5d3a79b9e 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -175,22 +175,3 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_every_two, cudaStream_t stream, int trans_count); - -void run_gemm(void* A, - void* B, - void* C, - void* a, - void* aa, - 
int M, - int N, - int K, - int groups, - int groups1, - cudaStream_t stream); - -void launch_me(int8_t* output, - float* scales, - __half* input, - int intermediate_size, - int batch_size, - cudaStream_t stream); From b17a3b5926787d772e8f590c3de9a5252472bc66 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sun, 28 Aug 2022 11:08:10 +0500 Subject: [PATCH 15/16] fix some issue with int8 ckpt-loading --- deepspeed/module_inject/load_checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index b7f14ca0ebcd..560f1bc83bc7 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -45,7 +45,7 @@ def load_parameters(module, prefix): if prefix + n in sd[0] and len(n.split('.')) == 1: if type(sd[0][prefix + n]) is list: tmp_data, scale = sd[0][prefix + n] - tmp_data = tmp_data.to(torch.cuda.current_device()) + tmp_data = tmp_data scale = scale.to(torch.cuda.current_device()) else: tmp_data = sd[0][prefix + n].to(torch.cuda.current_device()) @@ -62,7 +62,8 @@ def load_parameters(module, prefix): transpose(tmp_data) if weight_quantizer. q_int8 else tmp_data) else: - p = tmp_data + p = torch.nn.parameter.Parameter(tmp_data, + requires_grad=False) p.scale = scale setattr(module, n, p) else: From f3f4b1dda9f9ba09042b8b63bbee475c9364bc56 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 31 Aug 2022 04:21:47 +0500 Subject: [PATCH 16/16] change the mp_size to tp_size at inference config & add some doc-string at init_inference --- deepspeed/__init__.py | 14 +++++++++++++- deepspeed/module_inject/replace_module.py | 5 +++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 0c5158562126..8d3fa725164c 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -279,7 +279,19 @@ def init_inference(model, of groups used in quantization. A tuple is passed in if we want to mention that there is extra-grouping for the MLP part of a Transformer layer (e.g. (True, 8) shows we quantize the model using 8 groups for all the network except the MLP part that we use 8 extra grouping). - replace_with_kernel_inject: If set we inject kernel as we initialize the inference-engine + replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. Otherwise, + the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection) + return_tuple: Specify whether or not the transformer layers need to return a tuple or a Tensor. It is set to True by default (returning a tuple). + ep_size: The expert-parallelism size which is used for partitioning the experts across the GPUs in the expert-parallel group. + moe: Specify if the type of Transformer is MoE. It is set to False by default. + moe_experts: The global number of experts used in an MoE layer. + moe_type: Specify the type of MoE layer. We have two types of MoE layer: 'Standard' and 'Residual'. It is set to 'Standard' type by default. + args: All the arguments used for launching the inference api that can be useful at the inference-engine for injecting the optimizations. 
+ enable_cuda_graph: use this flag for capturing the CUDA-Graph of the inference ops, so that it can run faster using the graph replay method, + this is set to False by default + save_mp_checkpoint_path: The path for which we want to save the loaded model with a checkpoint. This feature is used for adjusting the + parallelism degree to help alleviate the model loading overhead. It does not save any new checkpoint if no path is passed. + base_dir: This shows the root directory under which all the checkpoint files exists. This can be passed through the json config too. Returns: A deepspeed.InferenceEngine wrapped model. diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 8f22ce0bad14..ccda3c8132b1 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -848,7 +848,8 @@ def replace_fn(child, _policy, layer_id=0): checkpoint = checkpoint_dict['checkpoints'] ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint ckpt_type = checkpoint_dict.get('parallelization', 'pp') - ckpt_mp_size = checkpoint_dict.get('mp_size', len(ckpt_list)) + ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list)) + ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size) base_dir1 = checkpoint_dict.get('base_dir', base_dir) if ckpt_type == 'pp' and type(checkpoint) is list: @@ -969,7 +970,7 @@ def replace_fn(child, _policy, layer_id=0): 1.0, 'parallelization': 'tp', - 'mp_size': + 'tp_size': world_size, 'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32')
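
The docstring added in [PATCH 16/16] describes the int8 and sharded-checkpoint arguments one by one; a minimal usage sketch may make the intended end-to-end flow clearer. The model name, paths, and parallel degree below are illustrative assumptions only and are not taken from these patches; the sketch assumes the standard init_inference arguments (mp_size, dtype, checkpoint, replace_with_kernel_inject) alongside the base_dir and save_mp_checkpoint_path arguments introduced here.

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Placeholder model and paths; substitute your own.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m",
                                             torch_dtype=torch.float16)

engine = deepspeed.init_inference(
    model,
    mp_size=2,                                    # tensor-parallel degree at serve time
    dtype=torch.int8,                             # enables the int8 path from this series
    replace_with_kernel_inject=True,
    checkpoint="/ckpt/ds-inference_config.json",  # assumed pre-sharded checkpoint config
    base_dir="/ckpt",                             # root directory for the checkpoint files
    save_mp_checkpoint_path="/ckpt/resharded")    # optional: re-save tp-sharded int8 shards

# engine.module is the kernel-injected model and is used like the original HF module.

When save_mp_checkpoint_path is set, the saving logic in [PATCH 08/16] through [PATCH 13/16] writes a single non-tp.pt, eight tp_{rank:0>2d}_{m:0>2d}.pt partitions per rank, and a ds-inference_config.json recording the file list, tp size, and dtype; pointing checkpoint and base_dir at that output on a later run skips the resharding work.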
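
The GroupQuantizer changes in [PATCH 08/16] come down to symmetric per-group int8 quantization: a weight is viewed as num_groups rows, each row gets a scale derived from its absolute maximum, and the resulting int8 tensor carries those scales in a .scale attribute. The sketch below is a condensed, standalone illustration of that idea, not the DeepSpeed implementation itself, which additionally splits along the parallel dimension and appends extra scale groups.

import torch

def group_quantize(weight: torch.Tensor, num_groups: int = 32, num_bits: int = 8):
    """Symmetric per-group quantization: int8 data plus one fp32 scale per group."""
    q_range = 2 ** num_bits
    flat = weight.reshape(num_groups, -1).float()
    # one scale per group, taken from the group's absolute maximum
    scale = flat.abs().max(dim=1, keepdim=True)[0] * 2.0 / q_range
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1).to(torch.int8)
    return q.reshape(weight.shape), scale.squeeze(1)

def group_dequantize(q: torch.Tensor, scale: torch.Tensor, num_groups: int = 32):
    flat = q.reshape(num_groups, -1).float() * scale.unsqueeze(1)
    return flat.reshape(q.shape)

w = torch.randn(1024, 1024, dtype=torch.float16)
q, s = group_quantize(w)
w_hat = group_dequantize(q, s)   # dequantized copy, close to w up to per-group rounding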
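
The shard-selection arithmetic added in [PATCH 08/16] (tp_split_size, sd_offset, sd_count) decides which entries of the tp checkpoint list each serving rank reads when the serving world size differs from the tp size the checkpoint was written with. A self-contained mirror of that indexing, with made-up sizes, shows the mapping:

def shards_for_rank(rank, world_size, ckpt_mp_size, num_checkpoints):
    """Mirror of the ckpt_index/sd_count indexing in replace_module.py (illustrative only)."""
    tp_split_size = world_size / ckpt_mp_size
    sd_offset = int(rank / tp_split_size)
    sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
    indices = []
    for i in range(num_checkpoints):   # num_checkpoints = len(ckpt_list) // ckpt_mp_size
        ckpt_index = i * ckpt_mp_size + sd_offset
        indices.extend(ckpt_index + j for j in range(sd_count))
    return indices

# A checkpoint written with tp_size=8 served on 4 GPUs: every rank merges two shards.
print(shards_for_rank(rank=0, world_size=4, ckpt_mp_size=8, num_checkpoints=1))  # [0, 1]
print(shards_for_rank(rank=3, world_size=4, ckpt_mp_size=8, num_checkpoints=1))  # [6, 7]
# The reverse case, tp_size=4 served on 8 GPUs: pairs of ranks read the same shard and
# then split it using the int(rank % tp_split_size) offset passed to the loader.
print(shards_for_rank(rank=0, world_size=8, ckpt_mp_size=4, num_checkpoints=1))  # [0]
print(shards_for_rank(rank=1, world_size=8, ckpt_mp_size=4, num_checkpoints=1))  # [0]

Note that the merge case (serving on fewer GPUs than the checkpoint's tp size) is rejected for int8 checkpoints by the assertion added to load_checkpoint.py in [PATCH 08/16], so the first worked pair above applies to fp16 checkpoints only.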
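
The backup_attention rewrite in [PATCH 10/16] changes the cache layout of the alibi (BLOOM-style) attention path: keys are stored pre-transposed as [batch*heads, head_dim, seq] and values as [batch*heads, seq, head_dim], so past keys are concatenated on dim=-1, past values on dim=-2, and the score matmul no longer needs per-step transposes. A shape-only trace with made-up sizes follows; alibi, the attention mask, and the softmax scaling are omitted.

import torch

bsz, heads, head_dim = 2, 4, 8    # made-up sizes
q_len, past_len = 1, 5            # one new token attending over five cached ones

query = torch.randn(bsz * heads, q_len, head_dim)      # [b*h, q, d]
key   = torch.randn(bsz * heads, head_dim, q_len)      # [b*h, d, q]  (already transposed)
value = torch.randn(bsz * heads, q_len, head_dim)      # [b*h, q, d]

past_key   = torch.randn(bsz * heads, head_dim, past_len)
past_value = torch.randn(bsz * heads, past_len, head_dim)

key   = torch.cat((past_key, key), dim=-1)             # -> [b*h, d, past+q]
value = torch.cat((past_value, value), dim=-2)         # -> [b*h, past+q, d]
present = (key, value)                                 # cached for the next decoding step

scores  = torch.matmul(query, key)                     # [b*h, q, past+q]
probs   = torch.softmax(scores, dim=-1)
context = torch.bmm(probs, value)                      # [b*h, q, d]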