-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add quant unit test #2315
Merged
Merged
add quant unit test #2315
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
48bd182
add quant unit test
GuanhuaWang 7f56db6
add codeowner
GuanhuaWang e696ed3
format fix
GuanhuaWang b5a5508
fix undefined symbol: curandSetPseudoRandomGeneratorSeed
GuanhuaWang 0f5dc1b
modify ref fn name and add comment
GuanhuaWang dfd3ee7
add comments
GuanhuaWang 5cd45c8
add 4bit quant 16groups
GuanhuaWang 325b86a
fix
GuanhuaWang caa716f
modify groups in ref code
GuanhuaWang 9410fcd
parameterize tensor shape
GuanhuaWang 1ebd819
single param
GuanhuaWang 9f6fa41
detach tensor
GuanhuaWang 798c2f0
Merge branch 'master' into guanhua/quant-new-test
awan-10 6bc80bb
remove -lcurand flag
GuanhuaWang 2cf46f1
add back -lcurand flag
GuanhuaWang 4f0b71e
Merge branch 'master' into guanhua/quant-new-test
GuanhuaWang c5d2173
Merge branch 'master' into guanhua/quant-new-test
GuanhuaWang 2a5b3ad
format
GuanhuaWang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft @samadejacobs @cmikeh2 | ||
* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft @samadejacobs @cmikeh2 @GuanhuaWang |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
import pytest | ||
from deepspeed.ops import op_builder | ||
|
||
quantizer_cuda_module = None | ||
|
||
|
||
def allclose(x, y): | ||
assert x.dtype == y.dtype | ||
rtol, atol = {torch.float32: (2e-1, 5e-2), torch.float16: (2e-1, 5e-2)}[x.dtype] | ||
return torch.allclose(x, y, rtol=rtol, atol=atol) | ||
|
||
|
||
def quantize_dequantize_ref(inputs, bit, num_groups=1): | ||
# quantize | ||
q_range = 2**bit | ||
input_flat = inputs.float().reshape(num_groups, -1).contiguous() | ||
input_flat = torch.nan_to_num(input_flat, nan=0.0) | ||
input_min = input_flat.amin(-1, keepdim=True) | ||
input_max = input_flat.amax(-1, keepdim=True) | ||
|
||
scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs())) | ||
input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1) | ||
# dequantize | ||
dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16) | ||
return torch.t(dequant_flat).reshape(inputs.shape) | ||
|
||
|
||
def run_quant_dequant(inputs, groups, bits): | ||
global quantizer_cuda_module | ||
|
||
if quantizer_cuda_module is None: | ||
quantizer_cuda_module = op_builder.QuantizerBuilder().load() | ||
return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits) | ||
|
||
|
||
@pytest.mark.inference | ||
@pytest.mark.parametrize("tensor_shape", [(8, 8), (128, 256)]) | ||
def test_quant_dequant(tensor_shape): | ||
input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda() | ||
|
||
# test 8bit quant/dequant on tensor partitioned in 1 group. | ||
ref_input_8bit_1group = input_tensor.clone().detach() | ||
ds_input_8bit_1group = input_tensor.clone().detach() | ||
ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8) | ||
# run_quant_dequant will do quantize then dequantize and return the dequantized value. | ||
ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, 1, 8) | ||
assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group)) | ||
|
||
# test 4bit quant/dequant on tensor partitioned into 16 groups. | ||
# Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG. | ||
ref_input_4bit_16group = input_tensor.clone().detach() | ||
ds_input_4bit_16group = input_tensor.clone().detach() | ||
ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, 16) | ||
ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, 16, 4) | ||
assert (allclose(ds_out_4bit_16group, ref_out_4bit_16group)) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you refactor this so that the shapes come through as parameters for the test? This will help make it clearer if one of the conditions fails if one does, which will help with debugging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cmikeh2 done.