add quant unit test #2315

Merged
merged 18 commits on Sep 14, 2022
Changes from all commits
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -1 +1 @@
-* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft @samadejacobs @cmikeh2
+* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft @samadejacobs @cmikeh2 @GuanhuaWang
3 changes: 3 additions & 0 deletions op_builder/quantizer.py
@@ -20,3 +20,6 @@ def sources(self):
 
     def include_paths(self):
         return ['csrc/includes']
+
+    def extra_ldflags(self):
+        return ['-lcurand']
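The new extra_ldflags() hook links the quantizer extension against cuRAND. A minimal sketch of how the flag takes effect (illustrative, not part of this diff; it assumes the builder JIT-compiles the op's C++/CUDA sources on load, as DeepSpeed's op builders do):

from deepspeed.ops import op_builder

# Loading the builder JIT-compiles the op's sources; the flags returned by
# extra_ldflags() are passed through to the extension build, so the resulting
# module is linked with -lcurand.
quantizer_module = op_builder.QuantizerBuilder().load()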
56 changes: 56 additions & 0 deletions tests/unit/ops/quantizer/test_quant.py
@@ -0,0 +1,56 @@
import torch
import pytest
from deepspeed.ops import op_builder

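# Cache for the JIT-built quantizer op module; populated on first use by
# run_quant_dequant below.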
quantizer_cuda_module = None


def allclose(x, y):
    assert x.dtype == y.dtype
    rtol, atol = {torch.float32: (2e-1, 5e-2), torch.float16: (2e-1, 5e-2)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def quantize_dequantize_ref(inputs, bit, num_groups=1):
    # quantize
    q_range = 2**bit
    input_flat = inputs.float().reshape(num_groups, -1).contiguous()
    input_flat = torch.nan_to_num(input_flat, nan=0.0)
    input_min = input_flat.amin(-1, keepdim=True)
    input_max = input_flat.amax(-1, keepdim=True)

    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs()))
    input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
    # dequantize
    dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
    return torch.t(dequant_flat).reshape(inputs.shape)
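As a worked example of the reference arithmetic above (an annotation, not part of the file): for bit=8, q_range is 256, so a group whose largest magnitude is 2.0 gets scale = 256 / (2 * 2.0) = 64; values are rounded and clamped into [-128, 127], then divided back by the scale.

# Worked example (illustrative, not part of the test file):
x = torch.tensor([[2.0, -1.0, 0.5, -2.0]], dtype=torch.float16)
out = quantize_dequantize_ref(x, bit=8, num_groups=1)
# scale = 64, so 2.0 -> clamp(round(2.0 * 64)) = 127 -> 127 / 64 = 1.984375,
# while -2.0 -> -128 -> exactly -2.0 after dequantization.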


def run_quant_dequant(inputs, groups, bits):
    global quantizer_cuda_module

    if quantizer_cuda_module is None:
        quantizer_cuda_module = op_builder.QuantizerBuilder().load()
    return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits)


@pytest.mark.inference
Contributor (@cmikeh2) commented:

Can you refactor this so that the shapes come through as parameters for the test? That way, if one of the conditions fails, it is clearer which one, which will help with debugging.

Member Author replied:

@cmikeh2 done.

@pytest.mark.parametrize("tensor_shape", [(8, 8), (128, 256)])
def test_quant_dequant(tensor_shape):
    input_tensor = torch.rand(tensor_shape, dtype=torch.float16).cuda()

    # Test 8-bit quant/dequant on a tensor partitioned into 1 group.
    ref_input_8bit_1group = input_tensor.clone().detach()
    ds_input_8bit_1group = input_tensor.clone().detach()
    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8)
    # run_quant_dequant quantizes, then dequantizes, and returns the dequantized value.
    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, 1, 8)
    assert allclose(ds_out_8bit_1group, ref_out_8bit_1group)

    # Test 4-bit quant/dequant on a tensor partitioned into 16 groups.
    # Note the explicit bound on groups: (((size / groups) - 1) / 4096 + 1) <= MAX_REG.
    ref_input_4bit_16group = input_tensor.clone().detach()
    ds_input_4bit_16group = input_tensor.clone().detach()
    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, 16)
    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, 16, 4)
    assert allclose(ds_out_4bit_16group, ref_out_4bit_16group)
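For a quick check outside the pytest harness, the same pieces can be exercised directly. A minimal smoke-test sketch, assuming a CUDA device and a toolchain that can JIT-compile the quantizer op on first use:

# Smoke test (illustrative, not part of the PR): 8-bit, single-group round trip.
x = torch.rand((128, 256), dtype=torch.float16).cuda()
ref = quantize_dequantize_ref(x.clone().detach(), 8)
out = run_quant_dequant(x.clone().detach(), 1, 8)
assert allclose(out, ref)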