Ds-inference Int8 support through ZeroQuant technology #2217

Merged: 25 commits from ds-inference/ZeroQuant-Int8 into master on Aug 30, 2022
Changes from all commits
25 commits
cf2fe01
Fix the layer-past for GPT based models
Aug 8, 2022
c2cf304
add the Int8 support for ds-inference using ZeroQuant technology
Aug 13, 2022
d98f1f9
fixing some issue with loading checkpoint and bias-add
Aug 15, 2022
ebc82bb
adding the logic to store/restore scale for INT8 checkpoint
Aug 15, 2022
43a7023
add empty quantization scale for different models to run with fp16
Aug 15, 2022
00aa188
Empty-Commit
Aug 15, 2022
9bed645
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 15, 2022
84e0d03
fix several issues after merging with master
Aug 18, 2022
f6cb028
several fixes for generating the INT8 sharded checkpoint
Aug 19, 2022
d47bea6
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 19, 2022
cb72d9c
move quantizer declaration before inference branch
Aug 20, 2022
32b9322
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 24, 2022
57779ef
fixing some part to catch up with latest update on HF side
Aug 24, 2022
f4e48e6
Merge branch 'ds-inference/ZeroQuant-Int8' of github.com:microsoft/De…
Aug 24, 2022
dbcb6ec
reducing the CPU memory usage when loading checkpoint (this solves th…
Aug 25, 2022
cd80ecc
some minor modification to the ckpt names
Aug 25, 2022
82a37d6
remove masking and some configuration changes
Aug 26, 2022
9d12656
remove dead code
Aug 26, 2022
4ae356e
Merge branch 'master' into ds-inference/ZeroQuant-Int8
jeffra Aug 26, 2022
d7ff364
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 28, 2022
b17a3b5
fix some issue with int8 ckpt-loading
Aug 28, 2022
a541e52
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 29, 2022
2845bad
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 30, 2022
c77f5e0
Merge branch 'master' into ds-inference/ZeroQuant-Int8
RezaYazdaniAminabadi Aug 30, 2022
f3f4b1d
change the mp_size to tp_size at inference config & add some doc-stri…
Aug 30, 2022
87 changes: 87 additions & 0 deletions csrc/transformer/inference/csrc/dequantize.cu
@@ -108,3 +108,90 @@ template void launch_dequantize<__half>(__half*,
unsigned,
unsigned,
cudaStream_t);

__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int hidden_dim,
unsigned merge_hidden,
int cnt)
{
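// Empty stub: the FP32 dequantize path is not implemented in this kernel;
// the __half overload below carries the actual INT8 -> FP16 conversion.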
}

__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt)
{
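// Each thread reads one 32-bit word (4 packed int8 values), rescales every
// value with this block's group scale, and writes the result back as a
// float2 holding 4 __half values.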
unsigned bid = blockIdx.x * gridDim.y + blockIdx.y;
unsigned tid = threadIdx.x;

float local_scale = qscale[blockIdx.x];

const float* input_cast = reinterpret_cast<const float*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);

input_cast += bid * merge_hidden;
output_cast += bid * merge_hidden;

for (int c = 0; c < cnt; c++) {
if (tid < merge_hidden) {
float q = input_cast[tid];
int8_t* q_int8 = (int8_t*)&q;

float2 q_f;
__half* q_h = (__half*)&q_f;

q_h[0] = __float2half(local_scale * (float)q_int8[0]);
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
// q_h[4] = __float2half(local_scale * (float)q_int8[4]);
// q_h[5] = __float2half(local_scale * (float)q_int8[5]);
// q_h[6] = __float2half(local_scale * (float)q_int8[6]);
// q_h[7] = __float2half(local_scale * (float)q_int8[7]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
}
}

template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
cudaStream_t stream)
{
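// hidden_dim becomes the number of 32-bit words per row (4 int8 values each);
// each block covers hid_cnt rows (or loops thd_cnt times per thread when a
// row is wider than one block), and grid.x spans the quantization groups.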
unsigned threads = 1024;
hidden_dim /= 4;
unsigned hid_cnt = threads / hidden_dim;
unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
hid_cnt = hid_cnt > 0 ? hid_cnt : 1;

unsigned blocks = output_size / hid_cnt / groups;
dim3 block_dims(threads);
dim3 grid_dims(groups, blocks);

dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt);
}

template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
cudaStream_t);
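
For reference, below is a minimal host-side sketch of how the new launch_dequantize overload might be driven. It is not part of this PR; the buffer shapes are illustrative, and the reading of output_size as the total row count, hidden_dim as elements per row, and groups as the number of quantization groups (one scale each, rows split contiguously across groups) is inferred from the launch logic above.

// sketch_dequantize.cu -- illustrative only, not part of this PR.
// Link against DeepSpeed's inference csrc (dequantize.cu) to resolve launch_dequantize.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

// Declaration matching the instantiation added in dequantize.cu.
template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       cudaStream_t stream);

int main()
{
    const unsigned rows = 4096;        // output_size: total rows of the weight
    const unsigned hidden_dim = 4096;  // elements per row (multiple of 4)
    const unsigned groups = 32;        // quantization groups, one scale per group

    int8_t* d_weight_int8;
    float* d_scales;
    __half* d_weight_fp16;
    cudaMalloc(&d_weight_int8, rows * hidden_dim * sizeof(int8_t));
    cudaMalloc(&d_scales, groups * sizeof(float));
    cudaMalloc(&d_weight_fp16, rows * hidden_dim * sizeof(__half));

    // ... copy the INT8 checkpoint weight and its per-group scales to the device here ...

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Expand the INT8 weight back to FP16 using the stored quantization scales.
    launch_dequantize<__half>(
        d_weight_fp16, d_weight_int8, d_scales, rows, hidden_dim, groups, stream);

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    cudaFree(d_weight_int8);
    cudaFree(d_scales);
    cudaFree(d_weight_fp16);
    return 0;
}
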
1 change: 1 addition & 0 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -1,5 +1,6 @@
#include "custom_cuda_layers.h"

namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
