mem access for quantize kernel #2331

Merged: 14 commits, Sep 22, 2022
csrc/quantization/quantizer.cu (173 changes: 79 additions, 94 deletions)
@@ -1,95 +1,84 @@
 #include <math.h>
 #include "custom_cuda_layers.h"
+#include "memory_access_utils.h"
 
 namespace cg = cooperative_groups;
 
 __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
 {
 #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
 
-    cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
+    cg::thread_block b = cg::this_thread_block();  // tb
+    cg::thread_block_tile<32> g =
+        cg::tiled_partition<32>(b);  // warp, 32 not optimal for AMD which should be 64.
 
     int gid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;
 
-    float2* vals_cast = reinterpret_cast<float2*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
 
-    float2 data[MAX_REG];
+    __half data[vals_per_access];
 
     int group_id = blockIdx.x;
 
-    {
-        int group_index = id;
-        int reg_count = 0;
-        int offset = group_id * group_size;
-        float max = -10000.0;
-
-        while (group_index < group_size && reg_count < MAX_REG) {
-            data[reg_count] = vals_cast[offset + group_index];
-            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
-
-            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
-            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
-            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
-            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
-
-            group_index += blockDim.x;
-            reg_count++;
-        }
+    int thread_index = id * vals_per_access;
+    int reg_count = 0;
+    int offset = group_id * group_size;
+    float max = -10000.0;
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 
 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_xor(max, i);
-            if (max < temp) max = temp;
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs((float)data[i]) > max) max = abs((float)data[i]);
         }
-        __shared__ float partialMax[WARP_SIZE];
-
-        if (lane == 0) partialMax[gid] = max;
-
-        b.sync();
-
-        if (lane < warp_num) max = partialMax[lane];
+    }
 
 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_down(max, i);
-            if (max < temp) max = temp;
-        }
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_xor(max, i);
+        if (max < temp) max = temp;
+    }
+    __shared__ float partialMax[WARP_SIZE];
 
-        max = g.shfl(max, 0);
+    if (lane == 0) partialMax[gid] = max;
 
-        float q_scale = (1 << num_bits) / (2 * max + 1e-5);
-        float q_scale_inv = 1 / q_scale;
-        for (int i = 0; i < reg_count; i++) {
-            group_index = i * blockDim.x + id;
-            if (group_index < group_size) {
-                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
-                float2 q_data[2];
-                q_data[0] = __half22float2(data_h[0]);
-                q_data[1] = __half22float2(data_h[1]);
+    b.sync();
 
-                float2 q_data_int[2];
+    if (lane < warp_num) max = partialMax[lane];
 
-                q_data_int[0].x = roundf(q_data[0].x * q_scale);
-                q_data_int[0].y = roundf(q_data[0].y * q_scale);
-                q_data_int[1].x = roundf(q_data[1].x * q_scale);
-                q_data_int[1].y = roundf(q_data[1].y * q_scale);
+#pragma unroll
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_down(max, i);
+        if (max < temp) max = temp;
+    }
 
-                q_data_int[0].x *= q_scale_inv;
-                q_data_int[0].y *= q_scale_inv;
-                q_data_int[1].x *= q_scale_inv;
-                q_data_int[1].y *= q_scale_inv;
+    max = g.shfl(max, 0);
 
-                data_h[0] = __float22half2_rn(q_data_int[0]);
-                data_h[1] = __float22half2_rn(q_data_int[1]);
+    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5);
+    float q_scale_inv = 1 / q_scale;
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
 
-                vals_cast[offset + group_index] = data[i];
-            }
-        }
-    }
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __half2float(data[j]);
+            q_data = __float2int_rn(q_data * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = __float2half_rn(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
+    }
 
 #endif
 }
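The round trip implemented above is plain symmetric quantization per group: find the group's absolute maximum, map [-max, max] onto 2^num_bits integer levels, round and clamp to [q_range_min, q_range_max], then multiply back by the inverse scale. A host-side sketch of just that math (an illustration only, not code from this PR; it uses nearbyint where the kernel uses __float2int_rn):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Reference for the per-group symmetric quantize/dequantize round trip
    // performed by quantize_kernel. `vals` holds one group of elements.
    void quantize_dequantize_group(std::vector<float>& vals, int num_bits)
    {
        // 1. Absolute maximum of the group.
        float max = 0.f;
        for (float v : vals) max = std::max(max, std::fabs(v));

        // 2. Scale that maps the symmetric range [-max, max] onto 2^num_bits levels;
        //    the 1e-5 term matches the kernel and guards against division by zero.
        float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5f);
        float q_scale_inv = 1.f / q_scale;
        int q_range_max = (1 << (num_bits - 1)) - 1;  // 127 for 8 bits
        int q_range_min = -(1 << (num_bits - 1));     // -128 for 8 bits

        // 3. Round to the nearest level, clamp, and immediately dequantize.
        for (float& v : vals) {
            int q = (int)std::nearbyint(v * q_scale);
            q = std::min(std::max(q, q_range_min), q_range_max);
            v = q * q_scale_inv;
        }
    }

The kernel computes the same thing, but finds the maximum cooperatively across the 1024-thread block and keeps each thread's vals_per_access elements in registers between the load and the store.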

@@ -103,31 +92,31 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;
 
-    float4* vals_cast = reinterpret_cast<float4*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
 
-    float4 data[MAX_REG];
+    float data[vals_per_access];
 
     int bid = blockIdx.x;
 
-    int group_index = bid * group_size + id;
+    int thread_index = id * vals_per_access;
 
     int reg_count = 0;
 
-    float max = -10000.0;
+    int offset = bid * group_size;
 
-    while (id < group_size && reg_count < MAX_REG) {
-        float4 data_reg = vals_cast[group_index];
-        data[reg_count] = data_reg;
+    float max = -10000.0;
 
-        if (abs(data_reg.x) > max) max = abs(data_reg.x);
-        if (abs(data_reg.y) > max) max = abs(data_reg.y);
-        if (abs(data_reg.z) > max) max = abs(data_reg.z);
-        if (abs(data_reg.w) > max) max = abs(data_reg.w);
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 
-        group_index += blockDim.x;
-        id += blockDim.x;
-        reg_count++;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs(data[i]) > max) max = abs(data[i]);
+        }
     }
-    id = threadIdx.x;
 
 #pragma unroll
     for (int i = 1; i < WARP_SIZE; i <<= 1) {
         auto temp = g.shfl_xor(max, i);
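The reduction above (and the identical one in the __half kernel) is the standard two-stage block maximum: a shuffle butterfly inside each warp, one shared-memory slot per warp, then a second shuffle pass over the per-warp partials. A stand-alone sketch of the pattern, assuming a warp size of 32 and at most 32 warps per block as in this file:

    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    // Block-wide max reduction (sketch of the pattern used in quantize_kernel).
    __device__ float block_reduce_max(cg::thread_block b, cg::thread_block_tile<32> g, float max)
    {
        // Stage 1: butterfly reduction; every lane ends up with its warp's max.
        for (int i = 1; i < 32; i <<= 1) {
            float temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partial_max[32];
        int lane = threadIdx.x & 0x1f;
        int warp_id = threadIdx.x >> 5;
        int warp_num = blockDim.x >> 5;

        // Stage 2: lane 0 of each warp publishes its warp max ...
        if (lane == 0) partial_max[warp_id] = max;
        b.sync();

        // ... then every warp re-reduces the per-warp partials (redundantly).
        if (lane < warp_num) max = partial_max[lane];
        for (int i = 1; i < 32; i <<= 1) {
            float temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        // After shfl_down, lane 0 holds the result; broadcast it across the warp
        // so every thread returns the block-wide maximum.
        return g.shfl(max, 0);
    }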
@@ -153,25 +142,22 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)

     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;
-    for (int i = 0; i < reg_count; i++) {
-        group_index = i * blockDim.x + id;
-        if (group_index < group_size) {
-            float4 q_data;
-            q_data = data[i];
 
-            float4 q_data_int;
-            q_data_int.x = roundf(q_data.x * q_scale);
-            q_data_int.y = roundf(q_data.y * q_scale);
-            q_data_int.w = roundf(q_data.w * q_scale);
-            q_data_int.z = roundf(q_data.z * q_scale);
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
 
-            q_data.x = q_data_int.x * q_scale_inv;
-            q_data.y = q_data_int.y * q_scale_inv;
-            q_data.w = q_data_int.w * q_scale_inv;
-            q_data.z = q_data_int.z * q_scale_inv;
-
-            vals_cast[group_index + bid * group_size] = q_data;
-        }
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __float2int_rn(data[j] * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = roundf(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
     }
 }
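mem_access::load_global<granularity> and store_global<granularity> come from the memory_access_utils.h header included at the top of the file. The effect is one 16-byte (128-bit) access per thread instead of per-element or float2/float4-cast loads, which is why vals_per_access is granularity / sizeof(T): 8 halves or 4 floats per access. A simplified stand-in for such a helper (a sketch, not the DeepSpeed implementation; it assumes the pointers are 16-byte aligned, which holds here because the base allocation is aligned and all offsets are multiples of vals_per_access):

    #include <cstring>
    #include <cuda_runtime.h>

    // Simplified 16-byte vectorized global load/store (illustration only).
    template <int GRANULARITY, typename T>
    __device__ void load_vec(T* dst, const T* src)
    {
        static_assert(GRANULARITY == 16, "sketch only covers 128-bit accesses");
        uint4 tmp = *reinterpret_cast<const uint4*>(src);  // single 128-bit load
        memcpy(dst, &tmp, GRANULARITY);                     // unpack into the register array
    }

    template <int GRANULARITY, typename T>
    __device__ void store_vec(T* dst, const T* src)
    {
        static_assert(GRANULARITY == 16, "sketch only covers 128-bit accesses");
        uint4 tmp;
        memcpy(&tmp, src, GRANULARITY);                     // pack from the register array
        *reinterpret_cast<uint4*>(dst) = tmp;               // single 128-bit store
    }

    // Usage mirrors the loops above:
    //   constexpr int granularity = 16;
    //   constexpr int vals_per_access = granularity / sizeof(__half);  // 8
    //   __half data[vals_per_access];
    //   load_vec<granularity>(data, vals + offset + thread_index);
    //   ... quantize data in registers ...
    //   store_vec<granularity>(vals + offset + thread_index, data);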

@@ -185,8 +171,7 @@ void launch_quantize_kernel(T* vals,
     dim3 grid_dim(group_num);
     dim3 block_dim(1024);
 
-    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        vals, (total_count / group_num) / 4, num_bits);
+    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, total_count / group_num, num_bits);
 }
 
 template void launch_quantize_kernel(float* vals,
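Because the kernels now index individual elements, the launcher passes the raw per-group element count; the removed / 4 existed only because the old kernels indexed float2 (4 halves) and float4 (4 floats) vectors. A worked example of the resulting launch geometry, using the (16, 4096) tensor from the updated test (illustration only):

    #include <cstdio>

    int main()
    {
        const int total_count = 16 * 4096;               // 65536 elements
        const int group_num = 16;                        // one CUDA block per group
        const int group_size = total_count / group_num;  // 4096 elements per block

        const int block_dim = 1024;                      // threads per block, as launched above
        const int vals_per_access = 16 / 2;              // granularity / sizeof(__half) = 8

        // Elements covered by one iteration of the strided loop in a block:
        const int per_iteration = block_dim * vals_per_access;  // 8192

        printf("group_size = %d, covered per iteration = %d\n", group_size, per_iteration);
        // A 4096-element group is therefore handled by the first 512 threads
        // of the block in a single loop iteration.
        return 0;
    }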
tests/unit/ops/quantizer/test_quant.py (42 changes: 23 additions, 19 deletions)
@@ -7,7 +7,7 @@

 def allclose(x, y):
     assert x.dtype == y.dtype
-    rtol, atol = {torch.float32: (2e-1, 5e-2), torch.float16: (2e-1, 5e-2)}[x.dtype]
+    rtol, atol = {torch.float32: (2e-2, 5e-3), torch.float16: (2e-2, 5e-3)}[x.dtype]
     return torch.allclose(x, y, rtol=rtol, atol=atol)


@@ -19,7 +19,7 @@ def quantize_dequantize_ref(inputs, bit, num_groups=1):
     input_min = input_flat.amin(-1, keepdim=True)
     input_max = input_flat.amax(-1, keepdim=True)
 
-    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs()))
+    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs() + 1e-5))
     input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
     # dequantize
     dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
@@ -35,22 +35,26 @@ def run_quant_dequant(inputs, groups, bits):


 @pytest.mark.inference
-@pytest.mark.parametrize("tensor_shape", [(8, 8), (128, 256)])
-def test_quant_dequant(tensor_shape):
+@pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
+# Test with two tensor shapes as (16, 4096) and (128, 256).
+@pytest.mark.parametrize("groups", [1, 16])
+# Test with number of quant groups as 1 and 16.
+# Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG.
+def test_quant_dequant(tensor_shape, groups):
 
     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()
 
-    # test 8bit quant/dequant on tensor partitioned in 1 group.
-    ref_input_8bit_1group = input_tensor.clone().detach()
-    ds_input_8bit_1group = input_tensor.clone().detach()
-    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8)
-    # run_quant_dequant will do quantize then dequantize and return the dequantized value.
-    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, 1, 8)
-    assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group))
-
-    # test 4bit quant/dequant on tensor partitioned into 16 groups.
-    # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG.
-    ref_input_4bit_16group = input_tensor.clone().detach()
-    ds_input_4bit_16group = input_tensor.clone().detach()
-    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, 16)
-    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, 16, 4)
-    assert (allclose(ds_out_4bit_16group, ref_out_4bit_16group))
+    # 8-bit quantization.
+    ref_input_8bit = input_tensor.clone().detach()
+    ds_input_8bit = input_tensor.clone().detach()
+    ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
+    # run_quant_dequant will do quantize then dequantize, and return the dequantized value.
+    ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
+    assert (allclose(ds_out_8bit, ref_out_8bit))
+
+    # 4-bit quantization.
+    ref_input_4bit = input_tensor.clone().detach()
+    ds_input_4bit = input_tensor.clone().detach()
+    ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)
+    ds_out_4bit = run_quant_dequant(ds_input_4bit, groups, 4)
+    assert (allclose(ds_out_4bit, ref_out_4bit))