Skip to content

Commit

Permalink
Refactor int4 weight only quantization with call to quantize
Browse files Browse the repository at this point in the history
Summary:
This is similar to #294 but applied for int4 weight only quantization

Test Plan:

unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793

integration perf test:

reference: elapsed_time:  2.5900126953125  milliseconds
after refactor: elapsed_time:  2.56680078125  milliseconds
diff: no diff

TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

Before:
After:
generated code diff:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Jun 1, 2024
1 parent 55a4676 commit 1e03a8d
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 197 deletions.
116 changes: 22 additions & 94 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from torchao.quantization.subclass import (
to_laq,
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Expand Down Expand Up @@ -138,39 +140,27 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)

def _ref_change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int8WeightOnlyQuantizedLinearWeight

filter_fn = kwargs.pop("filter_fn", _is_linear)
def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=True, **kwargs),
filter_fn,
)
filter_fn = kwargs.pop("filter_fn", _is_linear)

def _ref_change_linear_weights_to_int4_woqtensors(model, **kwargs):
"""
The deprecated implementation for int4 weight only quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int4WeightOnlyQuantizedLinearWeight
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
filter_fn,
)

filter_fn = kwargs.pop("filter_fn", _is_linear)
return _ref_change_linear_weights_to_woqtensors

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs),
filter_fn,
)
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
Expand Down Expand Up @@ -512,8 +502,7 @@ def test_quantized_tensor_subclass_int4(self):
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
Expand All @@ -534,9 +523,9 @@ def test_quantized_tensor_subclass_int8_wo(self):
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
_ref_change_linear_weights_to_int8_woqtensors(m_copy)


res = m(*example_inputs)
ref = m_copy(*example_inputs)

Expand All @@ -559,8 +548,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m_copy)
_ref_change_linear_weights_to_int8_dqtensors(m_copy)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
Expand All @@ -579,65 +567,5 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)


def _test_quantized_tensor_subclass_perf(self, api, ref_api, kwargs=None):
if kwargs is None:
kwargs = {}

m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_ref = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

api(m, **kwargs)

# reference
ref_api(m_ref, **kwargs)

res = m(*example_inputs)
ref = m_ref(*example_inputs)

self.assertTrue(torch.equal(res, ref))

# perf comparison
from torchao.utils import benchmark_model
# warmup
WARMUP = 5
RUNS = 100
input_tensor = example_inputs[0]
m = torch.compile(m, mode='max-autotune', fullgraph=True)

benchmark_model(m, WARMUP, input_tensor)
elapsed_time = benchmark_model(m, RUNS, input_tensor)

m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, input_tensor)
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 weight only quant implementation")
def test_quantized_tensor_subclass_int8_wo_quant_perf(self):
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int4 weight only quant implementation")
def test_quantized_tensor_subclass_int4_wo_quant_perf(self):
kwargs = {"groupsize": 32}
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
self._test_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)

if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 1e03a8d

Please sign in to comment.