From efcb3b5e5a5cd39add9aab1e8f70e0c6101ef37a Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Sun, 18 Sep 2022 05:00:56 +0000
Subject: [PATCH 01/11] mem access for quantize kernel

---
 csrc/quantization/quantizer.cu         | 57 +++++++++++---------
 tests/unit/ops/quantizer/test_quant.py |  2 +-
 2 files changed, 25 insertions(+), 34 deletions(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index 37883410e976..a026449526e6 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -1,5 +1,6 @@
 #include
 #include "custom_cuda_layers.h"
+#include "memory_access_utils.h"

 namespace cg = cooperative_groups;

@@ -7,37 +8,38 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
 {
 #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

-    cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
+    cg::thread_block b = cg::this_thread_block(); // thread block
+    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); // warp; 32 is not optimal for AMD, whose warp size is 64

     int gid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;

-    float2* vals_cast = reinterpret_cast<float2*>(vals);
+    constexpr int granularity = 8;
+    constexpr int vals_per_access = granularity / sizeof(__half);

-    float2 data[MAX_REG];
+    __half data[MAX_REG * vals_per_access];

     int group_id = blockIdx.x;

     {
-        int group_index = id;
+        int group_index = id * vals_per_access;
         int reg_count = 0;
-        int offset = group_id * group_size;
+        int offset = group_id * group_size * vals_per_access;
         float max = -10000.0;

         while (group_index < group_size && reg_count < MAX_REG) {
-            data[reg_count] = vals_cast[offset + group_index];
-            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
+            mem_access::load_global<granularity>(data + (reg_count*vals_per_access), vals + offset + group_index);

-            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
-            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
-            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
-            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
+#pragma unroll
+            for(int i=0; i<vals_per_access; i++){
+                if (abs((float)data[reg_count + i]) > max) max = abs((float)data[reg_count + i]);
+            }

-            group_index += blockDim.x;
+            group_index += blockDim.x * vals_per_access;
             reg_count++;
+
         }

 #pragma unroll
@@ -63,30 +65,19 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
         float q_scale = (1 << num_bits) / (2 * max + 1e-5);
         float q_scale_inv = 1 / q_scale;
+
         for (int i = 0; i < reg_count; i++) {
-            group_index = i * blockDim.x + id;
+            group_index = (i * blockDim.x + id) * vals_per_access;
             if (group_index < group_size) {
-                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
-                float2 q_data[2];
-                q_data[0] = __half22float2(data_h[0]);
-                q_data[1] = __half22float2(data_h[1]);
-
-                float2 q_data_int[2];
-                q_data_int[0].x = roundf(q_data[0].x * q_scale);
-                q_data_int[0].y = roundf(q_data[0].y * q_scale);
-                q_data_int[1].x = roundf(q_data[1].x * q_scale);
-                q_data_int[1].y = roundf(q_data[1].y * q_scale);
-
-                q_data_int[0].x *= q_scale_inv;
-                q_data_int[0].y *= q_scale_inv;
-                q_data_int[1].x *= q_scale_inv;
-                q_data_int[1].y *= q_scale_inv;
-
-                data_h[0] = __float22half2_rn(q_data_int[0]);
-                data_h[1] = __float22half2_rn(q_data_int[1]);
+#pragma unroll
+                for( int j = 0; j< vals_per_access; j++){
+                    float q_data;
+                    q_data = __half2float(data[i* vals_per_access+j]);

-                vals_cast[offset + group_index] = data[i];
+                    data[i*vals_per_access+j] = __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
+                }
+                mem_access::store_global<granularity>(vals + offset + group_index, data + (i*vals_per_access));
             }
         }
     }
diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index ea6b35860873..35f05aa82b99 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -7,7 +7,7 @@

 def allclose(x, y):
     assert x.dtype == y.dtype
-    rtol, atol = {torch.float32: (2e-1, 5e-2), torch.float16: (2e-1, 5e-2)}[x.dtype]
+    rtol, atol = {torch.float32: (2e-1, 5e-1), torch.float16: (2e-1, 5e-1)}[x.dtype]
     return torch.allclose(x, y, rtol=rtol, atol=atol)

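For reference, the per-group round trip that PATCH 01 establishes is plain symmetric quantization. Below is a minimal Python sketch (a hypothetical reference helper, not part of the patch, assuming torch is available) that mirrors the kernel's q_scale formula; note there is no clamping to the integer range yet — that only arrives in PATCH 07:

    import torch

    def quant_dequant_group(x: torch.Tensor, num_bits: int) -> torch.Tensor:
        # Same scale as the kernel: (1 << num_bits) / (2 * max|x| + 1e-5).
        max_abs = x.abs().max().item()
        q_scale = (1 << num_bits) / (2 * max_abs + 1e-5)
        # Quantize with round-to-nearest, then immediately dequantize.
        return torch.round(x * q_scale) / q_scale
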
From b4f800489ff762f5d570378f97308eea7620bc36 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Sun, 18 Sep 2022 05:04:48 +0000
Subject: [PATCH 02/11] format

---
 csrc/quantization/quantizer.cu | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index a026449526e6..ead11e6852c4 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -8,8 +8,9 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
 {
 #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

-    cg::thread_block b = cg::this_thread_block(); // thread block
-    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); // warp; 32 is not optimal for AMD, whose warp size is 64
+    cg::thread_block b = cg::this_thread_block();  // thread block
+    cg::thread_block_tile<32> g =
+        cg::tiled_partition<32>(b);  // warp; 32 is not optimal for AMD, whose warp size is 64

     int gid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
@@ -30,16 +31,16 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
         float max = -10000.0;

         while (group_index < group_size && reg_count < MAX_REG) {
-            mem_access::load_global<granularity>(data + (reg_count*vals_per_access), vals + offset + group_index);
+            mem_access::load_global<granularity>(data + (reg_count * vals_per_access),
+                                                 vals + offset + group_index);

 #pragma unroll
-            for(int i=0; i<vals_per_access; i++){
+            for (int i = 0; i < vals_per_access; i++) {
                 if (abs((float)data[reg_count + i]) > max) max = abs((float)data[reg_count + i]);
             }

             group_index += blockDim.x * vals_per_access;
             reg_count++;
-
         }

 #pragma unroll
@@ -67,17 +68,18 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
         float q_scale_inv = 1 / q_scale;

         for (int i = 0; i < reg_count; i++) {
-            group_index = (i * blockDim.x + id) * vals_per_access;
+            group_index = (i * blockDim.x + id) * vals_per_access;
             if (group_index < group_size) {
-
 #pragma unroll
-                for( int j = 0; j< vals_per_access; j++){
-                    float q_data;
-                    q_data = __half2float(data[i* vals_per_access+j]);
+                for (int j = 0; j < vals_per_access; j++) {
+                    float q_data;
+                    q_data = __half2float(data[i * vals_per_access + j]);

-                    data[i*vals_per_access+j] = __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
-                }
-                mem_access::store_global<granularity>(vals + offset + group_index, data + (i*vals_per_access));
+                    data[i * vals_per_access + j] =
+                        __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
+                }
+                mem_access::store_global<granularity>(vals + offset + group_index,
+                                                      data + (i * vals_per_access));
             }
         }
     }

From 7b66d552cc990f6cc2f7bb729c3256f498325902 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Sun, 18 Sep 2022 06:35:06 +0000
Subject: [PATCH 03/11] format fp32

---
 csrc/quantization/quantizer.cu | 52 ++++++++++++++++------------------
 1 file changed, 24 insertions(+), 28 deletions(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index ead11e6852c4..f3cd8808fb12 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -96,28 +96,31 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;

-    float4* vals_cast = reinterpret_cast<float4*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);

-    float4 data[MAX_REG];
+    float data[MAX_REG * vals_per_access];

     int bid = blockIdx.x;

-    int group_index = bid * group_size + id;
+    int group_index = id * vals_per_access;
+
     int reg_count = 0;
+
+    int offset = bid * group_size * vals_per_access;
+
     float max = -10000.0;

-    while (id < group_size && reg_count < MAX_REG) {
-        float4 data_reg = vals_cast[group_index];
-        data[reg_count] = data_reg;
+    while (group_index < group_size && reg_count < MAX_REG) {
+        mem_access::load_global<granularity>(data + (reg_count * vals_per_access),
+                                             vals + offset + group_index);

-        if (abs(data_reg.x) > max) max = abs(data_reg.x);
-        if (abs(data_reg.y) > max) max = abs(data_reg.y);
-        if (abs(data_reg.z) > max) max = abs(data_reg.z);
-        if (abs(data_reg.w) > max) max = abs(data_reg.w);
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs(data[reg_count + i]) > max) max = abs(data[reg_count + i]);
+        }

-        group_index += blockDim.x;
-        id += blockDim.x;
+        group_index += blockDim.x * vals_per_access;
         reg_count++;
     }

     id = threadIdx.x;
@@ -147,23 +150,16 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;
     for (int i = 0; i < reg_count; i++) {
-        group_index = i * blockDim.x + id;
+        group_index = (i * blockDim.x + id) * vals_per_access;
         if (group_index < group_size) {
-            float4 q_data;
-            q_data = data[i];
-
-            float4 q_data_int;
-            q_data_int.x = roundf(q_data.x * q_scale);
-            q_data_int.y = roundf(q_data.y * q_scale);
-            q_data_int.w = roundf(q_data.w * q_scale);
-            q_data_int.z = roundf(q_data.z * q_scale);
-
-            q_data.x = q_data_int.x * q_scale_inv;
-            q_data.y = q_data_int.y * q_scale_inv;
-            q_data.w = q_data_int.w * q_scale_inv;
-            q_data.z = q_data_int.z * q_scale_inv;
-
-            vals_cast[group_index + bid * group_size] = q_data;
+#pragma unroll
+            for (int j = 0; j < vals_per_access; j++) {
+                float q_data;
+                q_data = roundf(data[i * vals_per_access + j] * q_scale) * q_scale_inv;
+                data[i * vals_per_access + j] = q_data;
+            }
+            mem_access::store_global<granularity>(vals + offset + group_index,
+                                                  data + (i * vals_per_access));
         }
     }
 }

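With both kernels converted, every global access now moves vals_per_access = granularity / sizeof(T) contiguous elements (4 floats at 16 bytes here, 4 halves at 8 bytes in PATCH 01). A minimal Python emulation of one thread's traversal of its group (a hypothetical helper, not part of the patch; it ignores the launcher's unit bookkeeping, which PATCH 07 later cleans up, and max_reg stands in for MAX_REG, whose value lives in custom_cuda_layers.h — 16 is only an assumed value below):

    def thread_spans(group_size, block_dim, vals_per_access, tid, max_reg):
        """Element ranges one thread loads, mirroring the while loop above."""
        spans = []
        group_index = tid * vals_per_access
        reg_count = 0
        while group_index < group_size and reg_count < max_reg:
            # One mem_access::load_global<granularity> covers this span and
            # caches it in registers at data[reg_count * vals_per_access].
            spans.append(range(group_index, group_index + vals_per_access))
            group_index += block_dim * vals_per_access
            reg_count += 1
        return spans

    # Example: a 4096-element group, 1024 threads, 4 values per access ->
    # each thread issues exactly one vectorized load.
    assert len(thread_spans(4096, 1024, 4, tid=0, max_reg=16)) == 1
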
From 7c35302de314937d52a7b6553361d3a57580247d Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Tue, 20 Sep 2022 20:11:47 +0000
Subject: [PATCH 04/11] modify quant kernel

---
 csrc/quantization/quantizer.cu | 49 +++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index f3cd8808fb12..7480c7f1b90e 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -17,30 +17,30 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;

-    constexpr int granularity = 8;
+    constexpr int granularity = 16;
     constexpr int vals_per_access = granularity / sizeof(__half);

-    __half data[MAX_REG * vals_per_access];
+    __half data[vals_per_access];

     int group_id = blockIdx.x;

     {
-        int group_index = id * vals_per_access;
+        int thread_index = id * vals_per_access;
         int reg_count = 0;
         int offset = group_id * group_size * vals_per_access;
         float max = -10000.0;

-        while (group_index < group_size && reg_count < MAX_REG) {
-            mem_access::load_global<granularity>(data + (reg_count * vals_per_access),
-                                                 vals + offset + group_index);
+        while (thread_index < group_size ) {
+            mem_access::load_global<granularity>(data + vals_per_access,
+                                                 vals + thread_index);

 #pragma unroll
             for (int i = 0; i < vals_per_access; i++) {
-                if (abs((float)data[reg_count + i]) > max) max = abs((float)data[reg_count + i]);
+                if (abs((float)data[i]) > max) max = abs((float)data[i]);
             }

-            group_index += blockDim.x * vals_per_access;
-            reg_count++;
+            mem_access::store_global<granularity>(vals + thread_index, data);
+            thread_index += blockDim.x * vals_per_access;
         }

 #pragma unroll
@@ -67,21 +67,22 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
         float q_scale = (1 << num_bits) / (2 * max + 1e-5);
         float q_scale_inv = 1 / q_scale;

-        for (int i = 0; i < reg_count; i++) {
-            group_index = (i * blockDim.x + id) * vals_per_access;
-            if (group_index < group_size) {
+        while (thread_index < group_size) {
+            mem_access::load_global<granularity>(data,
+                                                 vals + thread_index);
 #pragma unroll
-                for (int j = 0; j < vals_per_access; j++) {
-                    float q_data;
-                    q_data = __half2float(data[i * vals_per_access + j]);
+            for (int j = 0; j < vals_per_access; j++) {
+                float q_data;
+                q_data = __half2float(data[j]);

-                    data[i * vals_per_access + j] =
-                        __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
-                }
-                mem_access::store_global<granularity>(vals + offset + group_index,
-                                                      data + (i * vals_per_access));
+                data[j] =
+                    __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
             }
+            mem_access::store_global<granularity>(vals + thread_index,
+                                                  data);
+            thread_index += blockDim.x * vals_per_access;
         }
-    }
 }
@@ -103,7 +104,7 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)

     int bid = blockIdx.x;

-    int group_index = id * vals_per_access;
+    int thread_index = id * vals_per_access;

     int reg_count = 0;

@@ -111,16 +112,16 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)

     float max = -10000.0;

-    while (group_index < group_size && reg_count < MAX_REG) {
+    while (thread_index < group_size && reg_count < MAX_REG) {
         mem_access::load_global<granularity>(data + (reg_count * vals_per_access),
-                                             vals + offset + group_index);
+                                             vals + offset + thread_index);

 #pragma unroll
         for (int i = 0; i < vals_per_access; i++) {
             if (abs(data[reg_count + i]) > max) max = abs(data[reg_count + i]);
         }

-        group_index += blockDim.x * vals_per_access;
+        thread_index += blockDim.x * vals_per_access;
         reg_count++;
     }
     id = threadIdx.x;
@@ -150,15 +151,15 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;
     for (int i = 0; i < reg_count; i++) {
-        group_index = (i * blockDim.x + id) * vals_per_access;
-        if (group_index < group_size) {
+        thread_index = (i * blockDim.x + id) * vals_per_access;
+        if (thread_index < group_size) {
 #pragma unroll
             for (int j = 0; j < vals_per_access; j++) {
                 float q_data;
                 q_data = roundf(data[i * vals_per_access + j] * q_scale) * q_scale_inv;
                 data[i * vals_per_access + j] = q_data;
             }
-            mem_access::store_global<granularity>(vals + offset + group_index,
+            mem_access::store_global<granularity>(vals + offset + thread_index,
                                                   data + (i * vals_per_access));
         }
     }

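The block-wide max reduction that feeds q_scale (untouched by this patch) is a warp-level butterfly over g.shfl_xor, followed by a shared-memory combine across warps. A small Python emulation of the intra-warp butterfly (hypothetical, not part of the patch):

    WARP_SIZE = 32

    def warp_max_butterfly(lane_vals):
        # Emulates: for (i = 1; i < WARP_SIZE; i <<= 1)
        #               max = fmax(max, shfl_xor(max, i));
        vals = list(lane_vals)
        i = 1
        while i < WARP_SIZE:
            # Each lane exchanges with its XOR partner; after log2(32) = 5
            # rounds, every lane holds the warp-wide maximum.
            vals = [max(vals[lane], vals[lane ^ i]) for lane in range(WARP_SIZE)]
            i <<= 1
        return vals

    assert len(set(warp_max_butterfly(range(WARP_SIZE)))) == 1  # all lanes agree
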
From d645531bd565be2d69f40c0cbc806af5005cda79 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Tue, 20 Sep 2022 20:13:12 +0000
Subject: [PATCH 05/11] modify quant kernel2

---
 csrc/quantization/quantizer.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index 7480c7f1b90e..581cbbab70b8 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -31,7 +31,7 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
         float max = -10000.0;

         while (thread_index < group_size ) {
-            mem_access::load_global<granularity>(data + vals_per_access,
+            mem_access::load_global<granularity>(data,
                                                  vals + thread_index);

 #pragma unroll

From 6e95dfdc04b5cbbe25dfcd919f81f398b53566a7 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Tue, 20 Sep 2022 20:20:28 +0000
Subject: [PATCH 06/11] modify format

---
 csrc/quantization/quantizer.cu | 57 +++++++++++++-------------------
 1 file changed, 25 insertions(+), 32 deletions(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index 581cbbab70b8..8dca2e10c827 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -25,14 +25,13 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
     int group_id = blockIdx.x;

     {
-        int thread_index = id * vals_per_access;
+        int thread_index = id * vals_per_access;
         int reg_count = 0;
         int offset = group_id * group_size * vals_per_access;
         float max = -10000.0;

-        while (thread_index < group_size ) {
-            mem_access::load_global<granularity>(data,
-                                                 vals + thread_index);
+        while (thread_index < group_size) {
+            mem_access::load_global<granularity>(data, vals + thread_index);

 #pragma unroll
             for (int i = 0; i < vals_per_access; i++) {
@@ -67,22 +66,18 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
         float q_scale = (1 << num_bits) / (2 * max + 1e-5);
         float q_scale_inv = 1 / q_scale;

-        while (thread_index < group_size) {
-            mem_access::load_global<granularity>(data,
-                                                 vals + thread_index);
+        while (thread_index < group_size) {
+            mem_access::load_global<granularity>(data, vals + thread_index);
 #pragma unroll
-            for (int j = 0; j < vals_per_access; j++) {
-                float q_data;
-                q_data = __half2float(data[j]);
-
-                data[j] =
-                    __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
-            }
-            mem_access::store_global<granularity>(vals + thread_index,
-                                                  data);
-            thread_index += blockDim.x * vals_per_access;
+            for (int j = 0; j < vals_per_access; j++) {
+                float q_data;
+                q_data = __half2float(data[j]);
+
+                data[j] = __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
             }
+            mem_access::store_global<granularity>(vals + thread_index, data);
+            thread_index += blockDim.x * vals_per_access;
+        }
     }
 #endif
 }
@@ -100,7 +95,7 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     constexpr int granularity = 16;
     constexpr int vals_per_access = granularity / sizeof(float);

-    float data[MAX_REG * vals_per_access];
+    float data[vals_per_access];

     int bid = blockIdx.x;

@@ -113,17 +108,16 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     float max = -10000.0;

     while (thread_index < group_size && reg_count < MAX_REG) {
-        mem_access::load_global<granularity>(data + (reg_count * vals_per_access),
-                                             vals + offset + thread_index);
+        mem_access::load_global<granularity>(data, vals + thread_index);

 #pragma unroll
         for (int i = 0; i < vals_per_access; i++) {
-            if (abs(data[reg_count + i]) > max) max = abs(data[reg_count + i]);
+            if (abs(data[i]) > max) max = abs(data[i]);
         }

         thread_index += blockDim.x * vals_per_access;
-        reg_count++;
     }
+
     id = threadIdx.x;
 #pragma unroll
     for (int i = 1; i < WARP_SIZE; i <<= 1) {
         auto temp = g.shfl_xor(max, i);
@@ -150,18 +144,17 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;
-    for (int i = 0; i < reg_count; i++) {
-        thread_index = (i * blockDim.x + id) * vals_per_access;
-        if (thread_index < group_size) {
+    while (thread_index < group_size) {
+        mem_access::load_global<granularity>(data, vals + thread_index);
 #pragma unroll
-            for (int j = 0; j < vals_per_access; j++) {
-                float q_data;
-                q_data = roundf(data[i * vals_per_access + j] * q_scale) * q_scale_inv;
-                data[i * vals_per_access + j] = q_data;
-            }
-            mem_access::store_global<granularity>(vals + offset + thread_index,
-                                                  data + (i * vals_per_access));
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __half2float(data[j]);
+
+            data[j] = __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
         }
+        mem_access::store_global<granularity>(vals + thread_index, data);
+        thread_index += blockDim.x * vals_per_access;
     }
 }

From 9cb8bef35dd6eff57af112124eff72f4525a2f3a Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Thu, 22 Sep 2022 00:21:41 +0000
Subject: [PATCH 07/11] format

---
 csrc/quantization/quantizer.cu         | 110 +++++++++++++-------------
 tests/unit/ops/quantizer/test_quant.py |  18 ++--
 2 files changed, 66 insertions(+), 62 deletions(-)

diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index 8dca2e10c827..41a11b3cfa53 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -24,61 +24,61 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)

     int group_id = blockIdx.x;

-    {
-        int thread_index = id * vals_per_access;
-        int reg_count = 0;
-        int offset = group_id * group_size * vals_per_access;
-        float max = -10000.0;
-
-        while (thread_index < group_size) {
-            mem_access::load_global<granularity>(data, vals + thread_index);
+    int thread_index = id * vals_per_access;
+    int reg_count = 0;
+    int offset = group_id * group_size;
+    float max = -10000.0;
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 #pragma unroll
-            for (int i = 0; i < vals_per_access; i++) {
-                if (abs((float)data[i]) > max) max = abs((float)data[i]);
-            }
-
-            mem_access::store_global<granularity>(vals + thread_index, data);
-            thread_index += blockDim.x * vals_per_access;
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs((float)data[i]) > max) max = abs((float)data[i]);
         }
+    }

 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_xor(max, i);
-            if (max < temp) max = temp;
-        }
-        __shared__ float partialMax[WARP_SIZE];
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_xor(max, i);
+        if (max < temp) max = temp;
+    }
+    __shared__ float partialMax[WARP_SIZE];

-        if (lane == 0) partialMax[gid] = max;
+    if (lane == 0) partialMax[gid] = max;

-        b.sync();
+    b.sync();

-        if (lane < warp_num) max = partialMax[lane];
+    if (lane < warp_num) max = partialMax[lane];

 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_down(max, i);
-            if (max < temp) max = temp;
-        }
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_down(max, i);
+        if (max < temp) max = temp;
+    }

-        max = g.shfl(max, 0);
+    max = g.shfl(max, 0);

-        float q_scale = (1 << num_bits) / (2 * max + 1e-5);
-        float q_scale_inv = 1 / q_scale;
+    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5);
+    float q_scale_inv = 1 / q_scale;
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));

-        while (thread_index < group_size) {
-            mem_access::load_global<granularity>(data, vals + thread_index);
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 #pragma unroll
-            for (int j = 0; j < vals_per_access; j++) {
-                float q_data;
-                q_data = __half2float(data[j]);
-
-                data[j] = __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
-            }
-            mem_access::store_global<granularity>(vals + thread_index, data);
-            thread_index += blockDim.x * vals_per_access;
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __half2float(data[j]);
+            q_data = __float2int_rn(q_data * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = __float2half_rn(q_data * q_scale_inv);
         }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
     }

 #endif
 }
@@ -103,22 +103,20 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)

     int reg_count = 0;

-    int offset = bid * group_size * vals_per_access;
+    int offset = bid * group_size;

     float max = -10000.0;

-    while (thread_index < group_size && reg_count < MAX_REG) {
-        mem_access::load_global<granularity>(data, vals + thread_index);
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);

 #pragma unroll
         for (int i = 0; i < vals_per_access; i++) {
             if (abs(data[i]) > max) max = abs(data[i]);
         }
-
-        thread_index += blockDim.x * vals_per_access;
     }

-    id = threadIdx.x;
 #pragma unroll
     for (int i = 1; i < WARP_SIZE; i <<= 1) {
         auto temp = g.shfl_xor(max, i);
@@ -144,17 +142,22 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;

-    while (thread_index < group_size) {
-        mem_access::load_global<granularity>(data, vals + thread_index);
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
+
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 #pragma unroll
         for (int j = 0; j < vals_per_access; j++) {
             float q_data;
-            q_data = __half2float(data[j]);
-
-            data[j] = __float2half_rn(roundf(q_data * q_scale) * q_scale_inv);
+            q_data = __float2int_rn(data[j] * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = roundf(q_data * q_scale_inv);
         }
-        mem_access::store_global<granularity>(vals + thread_index, data);
-        thread_index += blockDim.x * vals_per_access;
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
     }
 }

@@ -168,8 +171,7 @@ void launch_quantize_kernel(T* vals,
     dim3 grid_dim(group_num);
     dim3 block_dim(1024);

-    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        vals, (total_count / group_num) / 4, num_bits);
+    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, total_count / group_num, num_bits);
 }

 template void launch_quantize_kernel(float* vals,
diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index 35f05aa82b99..60ecbbc8a755 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -7,7 +7,7 @@

 def allclose(x, y):
     assert x.dtype == y.dtype
-    rtol, atol = {torch.float32: (2e-1, 5e-1), torch.float16: (2e-1, 5e-1)}[x.dtype]
+    rtol, atol = {torch.float32: (2e-2, 5e-3), torch.float16: (2e-2, 5e-3)}[x.dtype]
     return torch.allclose(x, y, rtol=rtol, atol=atol)

@@ -19,7 +19,7 @@ def quantize_dequantize_ref(inputs, bit, num_groups=1):
     input_min = input_flat.amin(-1, keepdim=True)
     input_max = input_flat.amax(-1, keepdim=True)
-    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs()))
+    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs()) + 1e-5)
     input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
     # dequantize
     dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
@@ -35,22 +35,24 @@ def run_quant_dequant(inputs, groups, bits):

 @pytest.mark.inference
-@pytest.mark.parametrize("tensor_shape", [(8, 8), (128, 256)])
-def test_quant_dequant(tensor_shape):
+@pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
+@pytest.mark.parametrize("groups", [1, 16])
+def test_quant_dequant(tensor_shape, groups):
+
     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()

     # test 8bit quant/dequant on tensor partitioned in 1 group.
     ref_input_8bit_1group = input_tensor.clone().detach()
     ds_input_8bit_1group = input_tensor.clone().detach()
-    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8)
+    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8, groups)
     # run_quant_dequant will do quantize then dequantize and return the dequantized value.
-    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, 1, 8)
+    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, groups, 8)
     assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group))

     # test 4bit quant/dequant on tensor partitioned into 16 groups.
     # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1 <= MAX_REG.
     ref_input_4bit_16group = input_tensor.clone().detach()
     ds_input_4bit_16group = input_tensor.clone().detach()
-    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, 16)
-    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, 16, 4)
+    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, groups)
+    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, groups, 4)
     assert (allclose(ds_out_4bit_16group, ref_out_4bit_16group))

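With the clamp in place, the device round trip now matches what the updated reference in test_quant.py checks: scale from the group's absolute max, round to nearest, clamp to the signed num_bits range, dequantize. A compact Python equivalent (a hypothetical helper mirroring q_range_max / q_range_min above, assuming torch; not part of the patch):

    import torch

    def quant_dequant(x: torch.Tensor, num_bits: int) -> torch.Tensor:
        q_scale = (1 << num_bits) / (2 * x.abs().max().item() + 1e-5)
        q_max = (1 << (num_bits - 1)) - 1   # 127 for 8 bits
        q_min = -(1 << (num_bits - 1))      # -128 for 8 bits
        q = torch.round(x * q_scale).clamp(q_min, q_max)
        return q / q_scale
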
From 557b508cd67c649c222eb9471d0aac73b5467808 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Thu, 22 Sep 2022 00:41:55 +0000
Subject: [PATCH 08/11] fix comments in pytest

---
 tests/unit/ops/quantizer/test_quant.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index 60ecbbc8a755..18d76101a08a 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -36,12 +36,14 @@ def run_quant_dequant(inputs, groups, bits):

 @pytest.mark.inference
 @pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
+# Test with two tensor shapes: (16, 4096) and (128, 256).
 @pytest.mark.parametrize("groups", [1, 16])
+# Test with number of quant groups: 1 and 16.
+# Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1 <= MAX_REG.
 def test_quant_dequant(tensor_shape, groups):

     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()

-    # test 8bit quant/dequant on tensor partitioned in 1 group.
     ref_input_8bit_1group = input_tensor.clone().detach()
     ds_input_8bit_1group = input_tensor.clone().detach()
     ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8, groups)
@@ -49,8 +51,6 @@ def test_quant_dequant(tensor_shape, groups):
     ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, groups, 8)
     assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group))

-    # test 4bit quant/dequant on tensor partitioned into 16 groups.
-    # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1 <= MAX_REG.
     ref_input_4bit_16group = input_tensor.clone().detach()
     ds_input_4bit_16group = input_tensor.clone().detach()
     ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, groups)

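The boundary stated in the relocated comment can be sanity-checked for the parametrized cases. A quick sketch (hypothetical; my reading of the comment is that 4096 is the per-tile element count and that MAX_REG, defined in custom_cuda_layers.h, bounds the tile count):

    def tiles_needed(size, groups, tile_elems=4096):
        # ((size / groups) - 1) / 4096 + 1 from the comment above.
        return (size // groups - 1) // tile_elems + 1

    for shape in [(16, 4096), (128, 256)]:
        numel = shape[0] * shape[1]
        for groups in [1, 16]:
            # Largest case, (16, 4096) with 1 group, needs 16 tiles.
            print(shape, groups, tiles_needed(numel, groups))
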
From 485be19737ef8948eb86674334bc3fe765268440 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Thu, 22 Sep 2022 00:45:07 +0000
Subject: [PATCH 09/11] fix comments in pytest

---
 tests/unit/ops/quantizer/test_quant.py | 24 +++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index 18d76101a08a..8b14b064b02b 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -44,15 +44,17 @@ def test_quant_dequant(tensor_shape, groups):

     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()

-    ref_input_8bit_1group = input_tensor.clone().detach()
-    ds_input_8bit_1group = input_tensor.clone().detach()
-    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8, groups)
+    # 8 bit quantization.
+    ref_input_8bit = input_tensor.clone().detach()
+    ds_input_8bit = input_tensor.clone().detach()
+    ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
     # run_quant_dequant will do quantize then dequantize and return the dequantized value.
-    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, groups, 8)
-    assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group))
-
-    ref_input_4bit_16group = input_tensor.clone().detach()
-    ds_input_4bit_16group = input_tensor.clone().detach()
-    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, groups)
-    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, groups, 4)
-    assert (allclose(ds_out_4bit_16group, ref_out_4bit_16group))
+    ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
+    assert (allclose(ds_out_8bit, ref_out_8bit))
+
+    # 4 bit quantization.
+    ref_input_4bit = input_tensor.clone().detach()
+    ds_input_4bit = input_tensor.clone().detach()
+    ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)
+    ds_out_4bit = run_quant_dequant(ds_input_4bit, groups, 4)
+    assert (allclose(ds_out_4bit, ref_out_4bit))

From 91582f2c3e2bbfc78b0f72d30fb072c0696385cd Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Thu, 22 Sep 2022 01:50:19 +0000
Subject: [PATCH 10/11] format

---
 tests/unit/ops/quantizer/test_quant.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index 8b14b064b02b..02ef89b58dda 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -44,7 +44,7 @@ def test_quant_dequant(tensor_shape, groups):

     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()

-    # 8 bit quantization.
+    # 8-bit quantization.
     ref_input_8bit = input_tensor.clone().detach()
     ds_input_8bit = input_tensor.clone().detach()
     ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
@@ -52,7 +52,7 @@ def test_quant_dequant(tensor_shape, groups):
     ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
     assert (allclose(ds_out_8bit, ref_out_8bit))

-    # 4 bit quantization.
+    # 4-bit quantization.
     ref_input_4bit = input_tensor.clone().detach()
     ds_input_4bit = input_tensor.clone().detach()
     ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)

From aa30df4183d7bf6fc3229da87edf9beded53a446 Mon Sep 17 00:00:00 2001
From: GuanhuaWang
Date: Thu, 22 Sep 2022 17:42:30 +0000
Subject: [PATCH 11/11] rerun

---
 tests/unit/ops/quantizer/test_quant.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index 02ef89b58dda..1526937dd2bc 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -48,7 +48,7 @@ def test_quant_dequant(tensor_shape, groups):
     ref_input_8bit = input_tensor.clone().detach()
     ds_input_8bit = input_tensor.clone().detach()
     ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
-    # run_quant_dequant will do quantize then dequantize and return the dequantized value.
+    # run_quant_dequant will do quantize then dequantize, and return the dequantized value.
     ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
     assert (allclose(ds_out_8bit, ref_out_8bit))