From a7670be5870029e5746d6e39154ea09111d5add4 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 27 Mar 2024 21:07:40 -0700 Subject: [PATCH 01/21] Use .to() instead of get_original_weight in linear_nf4 backward (#90) Co-authored-by: cpuhrsch --- torchao/dtypes/nf4tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 34ad778027..ea45a6c0de 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -569,9 +569,9 @@ def forward(ctx, input: torch.Tensor, weight: NF4Tensor): # inconsistently. def backward(ctx, grad_output): - """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()""" + """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)""" weight: NF4Tensor = ctx.nf4_weight - return grad_output @ weight.get_original_weight(), None + return grad_output @ weight.to(grad_output.dtype), None def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: From ec08d7189c895fe64e74e6f5f7090181426e4b74 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 29 Mar 2024 16:13:43 -0700 Subject: [PATCH 02/21] Change quantization version check to use 2.3.0.dev (#99) Summary: this is so that it works with executorch, which depends on torch 2.3.0 Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 21 ++++++++++++--------- test/quantization/test_quant_api.py | 4 ++-- torchao/quantization/quant_api.py | 8 ++++---- torchao/quantization/quant_primitives.py | 8 ++++---- torchao/quantization/utils.py | 8 ++++---- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index fd3a3311df..6873f7fbac 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -57,7 +57,7 @@ from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os from parameterized import parameterized -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 torch.manual_seed(0) config.cache_size_limit = 100 @@ -836,7 +836,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") @@ -846,7 +846,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") @@ -902,13 +902,14 @@ def test_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skip("flaky test, will fix in another PR") def test_int8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -918,7 +919,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -975,13 +976,14 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skip("flaky test, will fix in another PR") def test_int8_weight_only_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -995,7 +997,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1155,11 +1157,12 @@ def test_save_load_dqtensors(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() + @unittest.skip("flaky test, will fix in another PR") def test_save_load_int8woqtensors(self, device, dtype): self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1169,7 +1172,7 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "fullgraph requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "fullgraph requires torch nightly.") def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index cb5b8344ca..436dba0185 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -25,7 +25,7 @@ TwoStepQuantizer, ) from torchao.quantization.utils import ( - TORCH_VERSION_AFTER_2_4, + TORCH_VERSION_AFTER_2_3, ) from pathlib import Path from sentencepiece import SentencePieceProcessor @@ -136,7 +136,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") def test_8da4w_quantizer(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.quant_api import Int8DynActInt4WeightLinear diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4194ceb9be..fb83f90b22 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from .dynamic_quant import DynamicallyPerAxisQuantizedLinear -from .utils import TORCH_VERSION_AFTER_2_4 +from .utils import TORCH_VERSION_AFTER_2_3 from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, @@ -33,7 +33,7 @@ ) from .weight_only import WeightOnlyInt8QuantLinear -_AFTER_TORCH_2_4_ONLY = [ +_AFTER_TORCH_2_3_ONLY = [ "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer", ] @@ -48,7 +48,7 @@ "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", -] + (_AFTER_TORCH_2_4_ONLY if TORCH_VERSION_AFTER_2_4 else []) +] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) ############################# Unified Quantization APIs ############################## @@ -224,7 +224,7 @@ def replace_conv2d_1x1(conv): ) -if TORCH_VERSION_AFTER_2_4: +if TORCH_VERSION_AFTER_2_3: from .quant_primitives import ( get_group_qparams_symmetric, group_quantize_tensor_symmetric, diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index c8ff618154..98afa9a19c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -11,10 +11,10 @@ from torch.library import impl from torchao.kernel.intmm import int_scaled_matmul -from .utils import TORCH_VERSION_AFTER_2_4 +from .utils import TORCH_VERSION_AFTER_2_3 -_AFTER_TORCH_2_4_ONLY = [ +_AFTER_TORCH_2_3_ONLY = [ "per_token_dynamic_quant", "get_group_qparams_symmetric", ] @@ -38,7 +38,7 @@ "groupwise_affine_quantize_tensor", "groupwise_affine_dequantize_tensor", # TODO: need to clean up above functions -] + (_AFTER_TORCH_2_4_ONLY if TORCH_VERSION_AFTER_2_4 else []) +] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: @@ -571,7 +571,7 @@ def pack_scales_and_zeros(scales, zeros, precision=torch.float16): ) -if TORCH_VERSION_AFTER_2_4: +if TORCH_VERSION_AFTER_2_3: def group_quantize_tensor_symmetric( w, n_bit=4, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index e20ed6cfc5..1bca949b41 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -14,7 +14,7 @@ "compute_error", "_apply_logging_hook", "get_model_size_in_bytes", - "TORCH_VERSION_AFTER_2_4", + "TORCH_VERSION_AFTER_2_3", ] @@ -96,7 +96,7 @@ def get_model_size_in_bytes(model): return s -if version.parse(torch.__version__) >= version.parse("2.4.0.dev"): - TORCH_VERSION_AFTER_2_4 = True +if version.parse(torch.__version__) >= version.parse("2.3.0.dev"): + TORCH_VERSION_AFTER_2_3 = True else: - TORCH_VERSION_AFTER_2_4 = False + TORCH_VERSION_AFTER_2_3 = False From 5420089c79e94deb5f6a807c0b339025f59a359d Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Fri, 29 Mar 2024 19:55:47 -0700 Subject: [PATCH 03/21] Expand CI coverage to 2.2.2, 2.3rc and nightly (#96) --- .github/workflows/regression_test.yml | 25 +++++-- torchao/kernel/intmm.py | 89 ++++++++++++++---------- torchao/quantization/quant_primitives.py | 58 +-------------- torchao/quantization/utils.py | 9 +++ 4 files changed, 82 insertions(+), 99 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index a1bee9a23b..392ee1947c 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -9,7 +9,7 @@ on: - main jobs: - test: + test-cuda-2-2-2: runs-on: 4-core-ubuntu-gpu-t4 steps: - uses: actions/checkout@v2 @@ -22,10 +22,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install torch + pip install torch==2.2.2 pip install -r requirements.txt pip install -r dev-requirements.txt - - name: Install package run: | @@ -35,7 +34,24 @@ jobs: run: | pytest test --verbose -s -x - test-nightly: + test-cuda-2-3-rc: + runs-on: 4-core-ubuntu-gpu-t4 + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install -r requirements.txt + pip install -r dev-requirements.txt + + test-cuda-nightly: runs-on: 4-core-ubuntu-gpu-t4 steps: - uses: actions/checkout@v2 @@ -103,7 +119,6 @@ jobs: pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu pip install -r requirements.txt pip install -r dev-requirements.txt - - name: Install package run: | diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index bdd583520e..d2afa66a0a 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -2,52 +2,65 @@ import os import torch -from torch._dynamo import is_compiling as dynamo_is_compiling -from torch._higher_order_ops.out_dtype import out_dtype +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2 try: - from torchao.kernel import intmm_triton + # Only works for torch2.2 or newer. + if TORCH_VERSION_AFTER_2_2: + from torchao.kernel import intmm_triton + else: + intmm_triton = None except ImportError: + # On cpu-only builds might not be available. intmm_triton = None AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0))) -def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - # torch.compile path - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # error checking for cublas path - assert ( - mat2.device == input.device - ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" - device_cpu = "cpu" in [mat2.device.type, input.device.type] - # with input.shape = [i,j] and mat2.shape = [j,k] - i_is_strictly_greater_than_16 = input.shape[0] > 16 - j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) - k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) - bad_dimensions_for_cublas = not ( - i_is_strictly_greater_than_16 - and j_is_nonzero_multiple_of_8 - and k_is_nonzero_multiple_of_8 - ) - - if device_cpu or bad_dimensions_for_cublas: - # fallback path - return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( - input.device.type +# torch._int_mm doesn't exist before 2.2 +if TORCH_VERSION_AFTER_2_2: + from torch._dynamo import is_compiling as dynamo_is_compiling + from torch._higher_order_ops.out_dtype import out_dtype + def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + # error checking for cublas path + assert ( + mat2.device == input.device + ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + i_is_strictly_greater_than_16 = input.shape[0] > 16 + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + i_is_strictly_greater_than_16 + and j_is_nonzero_multiple_of_8 + and k_is_nonzero_multiple_of_8 ) - - # cublas paths - if not mat2.is_contiguous(): # silently gives incorrect result without this - mat2 = mat2.contiguous() - if (not input.is_contiguous()) and ( - input.shape[0] % 8 != 0 - ): # gives cryptic error without this - input = ( - input.contiguous() - ) # (it seems the transpose makes cublas check the above j constraint on i) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) + + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) +else: + def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + # We can improve on this by writing Triton code that works for older versions of Triton + # that ship with 2.1 or 2.0. + return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) def int_matmul(a, b): diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 98afa9a19c..3e05dd3c42 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -11,6 +11,8 @@ from torch.library import impl from torchao.kernel.intmm import int_scaled_matmul +from .utils import TORCH_VERSION_AFTER_2_4 +from torchao.kernel.intmm import safe_int_mm from .utils import TORCH_VERSION_AFTER_2_3 @@ -40,64 +42,8 @@ # TODO: need to clean up above functions ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) - -def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - r""" - This function wraps torch._int_mm and avoids several undesirable behaviors of the function for certain inputs while still - returning correct results and being torch.compiled in a performant way. - - Assumes both tensors have dimension of 2. - - Note: no error checking for torch.compiled path, if input.shape = [i, j] and j<=16 then the triton kernel - will error. - - Args: - input (Tensor, int8): the first tensor to be multiplied - mat2 (Tensor, int8): the second tensor to be multiplied - - Return: - out (Tensor, int32): the result of the matmul with device matching that of the inputs - """ - # torch.compile path - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # error checking for cublas path - assert ( - mat2.device == input.device - ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" - device_cpu = "cpu" in [mat2.device.type, input.device.type] - # with input.shape = [i,j] and mat2.shape = [j,k] - i_is_strictly_greater_than_16 = input.shape[0] > 16 - j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) - k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) - bad_dimensions_for_cublas = not ( - i_is_strictly_greater_than_16 - and j_is_nonzero_multiple_of_8 - and k_is_nonzero_multiple_of_8 - ) - - if device_cpu or bad_dimensions_for_cublas: - # fallback path - return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( - input.device.type - ) - - # cublas paths - if not mat2.is_contiguous(): # silently gives incorrect result without this - mat2 = mat2.contiguous() - if (not input.is_contiguous()) and ( - input.shape[0] % 8 != 0 - ): # gives cryptic error without this - input = ( - input.contiguous() - ) # (it seems the transpose makes cublas check the above j constraint on i) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736 - def dynamically_quantize_per_tensor( x, quant_min, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 1bca949b41..1f6b3a9bcf 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -95,8 +95,17 @@ def get_model_size_in_bytes(model): s += b.nelement() * b.element_size() return s +if version.parse(torch.__version__) >= version.parse("2.4.0.dev"): + TORCH_VERSION_AFTER_2_4 = True +else: + TORCH_VERSION_AFTER_2_4 = False if version.parse(torch.__version__) >= version.parse("2.3.0.dev"): TORCH_VERSION_AFTER_2_3 = True else: TORCH_VERSION_AFTER_2_3 = False + +if version.parse(torch.__version__) >= version.parse("2.2.0.dev"): + TORCH_VERSION_AFTER_2_2 = True +else: + TORCH_VERSION_AFTER_2_2 = False From 293ae7b79785bee872209667cd8703abaebcff08 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Sat, 30 Mar 2024 02:34:12 -0400 Subject: [PATCH 04/21] Refactor GPTQ Quantizer, remove lm_eval (#104) Summary: refactor GPTQ code, remove lm_eval dependency of gptq, remove model dependency of InputRecorder made GPTQ work with gpt-fast. also fixed model so its kv_cache doesn't break gptq. Test Plan: python test/quantization/test_quant_api.py Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: d6eb81037e37ba556bd37a3b248c51482be7c68d Pull Request resolved: https://github.com/pytorch-labs/ao/pull/103 --- test/quantization/model.py | 16 +- test/quantization/test_quant_api.py | 81 +- torchao/quantization/GPTQ.py | 938 +++++++++++++++++++---- torchao/quantization/__init__.py | 1 + torchao/quantization/quant_api.py | 644 +--------------- torchao/quantization/quant_primitives.py | 1 - torchao/quantization/unified.py | 29 + 7 files changed, 896 insertions(+), 814 deletions(-) create mode 100644 torchao/quantization/unified.py diff --git a/test/quantization/model.py b/test/quantization/model.py index 17a59e5bb0..b9705313e6 100644 --- a/test/quantization/model.py +++ b/test/quantization/model.py @@ -11,6 +11,16 @@ from torch import Tensor from torch.nn import functional as F +def prepare_inputs_for_model(inps): + # setup inputs in correct format + max_new_tokens = 1 + T = inps.size(0) + T_new = T + max_new_tokens + seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device) + seq[:T] = inps + input_pos = torch.arange(0, T, device=inps.device) + x = seq.index_select(0, input_pos).view(1, -1) + return (x, input_pos) def find_multiple(n: int, k: int) -> int: if n % k == 0: @@ -76,10 +86,8 @@ def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val + k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val) + v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val) return k_out, v_out diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 436dba0185..156fdbd78a 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -29,7 +29,7 @@ ) from pathlib import Path from sentencepiece import SentencePieceProcessor -from model import Transformer +from model import Transformer, prepare_inputs_for_model def dynamic_quant(model, example_inputs): @@ -139,9 +139,9 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") def test_8da4w_quantizer(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer - from torchao.quantization.quant_api import Int8DynActInt4WeightLinear + from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear - quantizer = Int8DynActInt4WeightQuantizer(group_size=32) + quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) m = M().eval() example_inputs = m.example_inputs() m = quantizer.quantize(m) @@ -151,7 +151,7 @@ def test_8da4w_quantizer(self): @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -169,20 +169,83 @@ def test_gptq_quantizer(self): percdamp = 0.01 groupsize = 128 calibration_tasks = ["wikitext"] - calibration_limit = 5 + calibration_limit = 1 calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - quantizer = Int8DynActInt4WeightGPTQQuantizer( + + inputs = InputRecorder( tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + + quantizer = Int8DynActInt4WeightGPTQQuantizer( blocksize, percdamp, groupsize, - calibration_tasks, - calibration_limit, + ) + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + model = quantizer.quantize(model, inputs) + compiled = torch.compile(model, mode="max-autotune") + with torch.no_grad(): + compiled(inputs[0].values[0], inputs[1].values[0]) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_gptq_quantizer_gpt_fast(self): + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder + # should be similar to TorchCompileDynamicQuantizer + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + blocksize = 128 + percdamp = 0.01 + groupsize = 128 + calibration_tasks = ["wikitext"] + calibration_limit = 1 + calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model + pad_calibration_inputs = False + + inputs = InputRecorder( + tokenizer, calibration_seq_length, + input_prep_func, pad_calibration_inputs, + model.config.vocab_size, + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + + quantizer = Int8DynActInt4WeightGPTQQuantizer( + blocksize, + percdamp, + groupsize, + _is_gpt_fast=True, + _use_cuda=True, ) - model = quantizer.quantize(model) + + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + + model = quantizer.quantize(model, inputs) + compiled = torch.compile(model, mode="max-autotune") + with torch.no_grad(): + compiled(inputs[0].values[0], inputs[1].values[0]) if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index a82edca528..304a84ac56 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -9,7 +9,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional +from typing import Optional, List import torch @@ -20,24 +20,14 @@ # from model import Transformer # pyre-ignore[21] from torch.utils._pytree import tree_flatten, tree_unflatten +from .utils import TORCH_VERSION_AFTER_2_4 +from typing import Any, Dict, Tuple, Optional +from .unified import Quantizer +from functools import reduce +from math import gcd aten = torch.ops.aten -## generate.py ## - - -def encode_tokens(tokenizer, string, bos=True, device="cuda"): - - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - - -def model_forward(model, x, input_pos): - return model(x, input_pos) - - ## eval.py ## try: @@ -61,68 +51,51 @@ def model_forward(model, x, input_pos): else: logging.info("lm_eval is not installed, GPTQ may not be usable") +if lm_eval_available: + class InputRecorder(eval_wrapper): + """ + This is a fake evaluation wrapper from the lm_eval library that just records the inputs + so that they can be used in calibration. -def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - model: torch.nn.Module, - prompt: torch.Tensor, - max_new_tokens: int, - max_seq_length: Optional[int] = None, - block_size: int = 2048, -): - """ - Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length - that are needed for prefill or model_forward - - Args: - model (torch.nn.Module): The model whose cache gets set up - prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence. - max_new_tokens (int): The desired maximum number of new tokens that can be generated. - max_seq_length (Optional[int], optional): The maximum sequence length allowed. - - Returns: - seq (torch.Tensor): prompt but padded with zeros to size max_seq_length - input_pos (torch.Tensor): tensor of integers in increasing order - max_seq_length (int): The maximum sequence length allowed, updated based on other numbers - """ - T = prompt.size(0) - T_new = T + max_new_tokens - if max_seq_length is None: - max_seq_length = min(T_new, block_size) - - device, dtype = prompt.device, prompt.dtype - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - # no caches in executorch llama2 7b model? - # with torch.device(device): - # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - return seq, input_pos, max_seq_length + If pad_calibration_inputs is enabled, the input recorder will take + each input and pad/truncate it down to the calibration_seq_length. + (if using padding you should set the embeddings for the pad_token to 0 + in the model) + Note: after padding/truncation, input_prep_function is called to bring + it to the proper form to be inserted into a given model. -if lm_eval_available: - - class GPTFastEvalWrapper(eval_wrapper): # pyre-ignore[11] - """ - A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. + If not, it will only truncate inputs to the desired length. """ def __init__( self, - model: torch.nn.Module, tokenizer, - max_seq_length: Optional[int] = None, + calibration_seq_length, + input_prep_func=None, + pad_calibration_inputs=False, + vocab_size=32000, + pad_token=0, + device="cpu", ): super().__init__() - self._model = model - self._tokenizer = tokenizer - self._device = torch.device("cuda") + self._device = torch.device(device) + self.vocab_size = vocab_size + self._max_seq_length = calibration_seq_length + self.calibration_seq_length = calibration_seq_length + + # need to take inps and convert to corrent input + # for model + self.input_prep_func = ( + input_prep_func if input_prep_func is not None + else lambda x: x + ) + + self.pad_calibration_inputs = pad_calibration_inputs + self.pad_token = pad_token - self._max_seq_length = 2048 if max_seq_length is None else max_seq_length + self.inputs = None @property def eot_token_id(self): @@ -145,89 +118,16 @@ def device(self): return self._device def tok_encode(self, string: str, **kwargs): - encoded = encode_tokens( - self._tokenizer, string, bos=True, device=self._device - ) - # encoded is a pytorch tensor, but some internal logic in the - # eval harness expects it to be a list instead # TODO: verify this for multi-batch as well - encoded = encoded.tolist() - return encoded + tokens = self._tokenizer.encode(string) + if hasattr(self._tokenizer, "bos_id"): + tokens = [self._tokenizer.bos_id()] + tokens + return tokens def tok_decode(self, tokens): decoded = self._tokenizer.decode(tokens) return decoded - def _model_call(self, inps): - print("in model_call") - # TODO: make batches work - inps = inps.squeeze(0) - - max_new_tokens = 1 - seq, input_pos, max_seq_length = ( - setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - self._model, - inps, - max_new_tokens, - self.max_length, - ) - ) - x = seq.index_select(0, input_pos).view(1, -1) - logits = model_forward(self._model, x, input_pos) - return logits - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - - class InputRecorder(GPTFastEvalWrapper): - """ - This is a fake evaluation wrapper that just records the inputs - so that they can be used in calibration. - - If pad_calibration_inputs is enabled, the input recorder will take - each input and pad/truncate it down to the calibration_seq_length. - It will also edit the model embeddings to be zero for the 0 token used - in padding and avoid any inputs with the 0 token. - - If not, it will only truncate inputs to the desired length. - """ - - def __init__( - self, - model: torch.nn.Module, - tokenizer, - calibration_seq_length, - pad_calibration_inputs=False, - ): - super().__init__(model, tokenizer, calibration_seq_length) - self._model = model - - self._tokenizer = tokenizer - self._device = torch.device("cpu") - - self.vocab_size = model.vocab_size - - self.calibration_seq_length = calibration_seq_length - - self.pad_calibration_inputs = pad_calibration_inputs - - self.inputs = None - - if self.pad_calibration_inputs: - # This is needed for the pad_calibration_inputs option - # to work properly, the 0 token's embeddings are set to 0 so that - # the padded inputs will not affect the model numerics. This token isn't used - # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs - # where it appears - try: - if isinstance(self._model.transformer.wte, nn.Embedding): - self._model.transformer.wte.weight.data[0, :] *= 0 - except: - print( - "Did not find embeddings in model.transformer.wte, disabling padding" - ) - self.pad_calibration_inputs = False - def add_input(self, args): if self.inputs is None: self.inputs = [MultiInput([arg]) for arg in args] @@ -236,7 +136,27 @@ def add_input(self, args): multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) ] - def get_recorded_inputs(self): + def record_inputs( + self, + calibration_tasks, + calibration_limit, + ): + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + + evaluate( + self, + task_dict, + limit=calibration_limit, + ) + return self + + def get_inputs(self): return self.inputs def _model_call(self, inps): @@ -247,7 +167,7 @@ def _model_call(self, inps): (T < self.calibration_seq_length and not self.pad_calibration_inputs) or # can't use inputs that actually use token we use for padding - (self.pad_calibration_inputs and 0 in inps) + (self.pad_calibration_inputs and self.pad_token in inps) ): # give random output return torch.randn( @@ -258,24 +178,20 @@ def _model_call(self, inps): if T >= self.calibration_seq_length: inps = inps[: self.calibration_seq_length] else: - inps = F.pad(inps, (0, self.calibration_seq_length - T)) - - max_new_tokens = 1 - ( - seq, - input_pos, - max_seq_length, - ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - self._model, inps, max_new_tokens, self.max_length - ) - x = seq.index_select(0, input_pos).view(1, -1) - self.add_input((x, input_pos)) + inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) + + model_in = self.input_prep_func(inps) + + self.add_input(model_in) # output `something` with correct shape to keep eval going return torch.randn( (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device ) + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + class MultiInput: @@ -305,8 +221,7 @@ class GenericGPTQRunner(fx.Interpreter): into the state_dict so that the quantized model weights/qparams can be loaded directly into the model. - This class is expected to work in concert with a GPTQSimpleQuantizer - class to define the specific type of quantization being done. + intended to be used in concert with a GPTQQuantizer class to define the quantization mode. """ def __init__( @@ -323,9 +238,9 @@ def __init__( } # trace model for one input - one_input = [multi.values[0] for multi in inputs] # pyre-ignore[16] + one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16] exported_model = torch._dynamo.export( - model, aten_graph=True, pre_dispatch=True, tracing_mode="fake" + model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake" )(*one_input) super().__init__(exported_model.graph_module) @@ -338,7 +253,7 @@ def __init__( self.groupsize = groupsize self.inputs = inputs self.gptq_done = False - self.debug = False + self.debug = True def configure_quantization_mode( self, @@ -348,7 +263,7 @@ def configure_quantization_mode( combine_qparams_list_func, make_names_and_values_dict_func, skip_layer_func, - dyn_quant_func = None, + act_fake_quant_func = None, ): # these functions need to already be curried with all inputs other than weight, qparams @@ -373,7 +288,8 @@ def configure_quantization_mode( self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict # note any final packing for storage should happen here - self.dyn_quant_func = dyn_quant_func + # `act_fake_quant_func` + self.act_fake_quant_func = act_fake_quant_func # accepts [activation tensor], returns a fake-quantized activation tensor return self def run(self): @@ -454,8 +370,8 @@ def tensors_to_cuda(args): quantize_linear ): # calculate H instead of output (will run the linear eventually with updated weight) x = cur_args[0].float() - if self.dyn_quant_func is not None: - x = self.dyn_quant_func(x) + if self.act_fake_quant_func is not None: + x = self.act_fake_quant_func(x) shape = x.shape n = 1 if len(shape) == 2 else shape[0] H *= total_batches / (total_batches + n) @@ -467,7 +383,8 @@ def tensors_to_cuda(args): else: # get output if its not a linear out = super().call_function(target, cur_args, cur_kwargs) - + # if isinstance(out, torch.Tensor) and (out.isnan().max() or out.sum()==0 or out.isinf().max()): + # breakpoint() if isinstance(out, torch.Tensor): outputs.append(out.cpu()) else: @@ -513,7 +430,6 @@ def SQNR(x, y): print( "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after) ) # matches - print( "SQNR for weight (can be low)", SQNR(W, DQ.cuda()) ) # fine to not match @@ -628,3 +544,689 @@ def faster_quant(self, H, W): all_qparams = self.combine_qparams_list_func(all_qparams) Q = self.quantize_func(DQ, all_qparams) return Q, DQ.to(orig_dtype), all_qparams + + +if TORCH_VERSION_AFTER_2_4: + from .quant_primitives import ( + get_group_qparams_symmetric, + group_quantize_tensor_symmetric, + per_token_dynamic_quant, + ) + + class GPTQQuantizer(Quantizer): + """ + This class implements a GPTQ Quantizer that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base Quantizer class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. + qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + act_fake_quant_func (optional): + A function that (dynamically) quantizes activation to input + Args: + input: input Tensor in f32/bf16/f16 + Returns: + output: dynamically quantized and dequantized Tensor (with the same dtype as input) + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. + """ + + def __init__(self): + + assert self.get_qparams_func is not None + + assert self.quantize_func is not None + + assert self.dequantize_func is not None + + assert self.combine_qparams_list_func is not None + + # `make_names_and_values_dict_func`. + assert self.make_names_and_values_dict_func is not None + + @torch.no_grad() + def _create_quantized_state_dict( + self, + model, + inputs, + blocksize, + percdamp, + groupsize, + # `typing.Dict[, ]` to avoid runtime subscripting errors. + ) -> Dict: + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + model, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, # pyre-ignore[16] + self.quantize_func, # pyre-ignore[16] + self.dequantize_func, # pyre-ignore[16] + self.combine_qparams_list_func, # pyre-ignore[16] + self.make_names_and_values_dict_func, # pyre-ignore[16] + self.skip_layer_func, # pyre-ignore[16] + self.act_fake_quant_func if hasattr(self, "act_fake_quant_func") else None, # pyre-ignore[16] + ) + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() + + def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": + raise NotImplementedError("_convert_for_runtime not implemented") + + @torch.no_grad() + def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: + pass + + + def linear_forward_8da4w( + x, + weight_int8, + scales, + zeros, + out_features, + groupsize, + precision, + ): + x = per_token_dynamic_quant(x) + # TODO: verify and remove following reshape code + # origin_x_size = x.size() + # x = x.reshape(-1, origin_x_size[-1]) + + # TODO: better API + # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed) + n_bit = 4 + quant_min = -(2 ** (n_bit - 1)) + quant_max = 2 ** (n_bit - 1) - 1 + w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group( + weight_int8, + scales, + zeros, + quant_min, + quant_max, + torch.int8, + groupsize, + precision, + ) + + # x = x.to(torch.float16) + # w_dq = w_dq.to(torch.float16) + c = torch.nn.functional.linear(x, w_dq) + + # new_shape = origin_x_size[:-1] + (out_features,) + # c = c.reshape(new_shape) + + return c + + + class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, in_features: int, out_features: int, + bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, + ) -> None: + super().__init__() + self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + if self.padding: + from model import find_multiple + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + if use_cuda: + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) + else: + self.register_buffer( + "weight", + torch.empty((out_features, in_features // 2), dtype=torch.uint8) + ) + self.register_buffer( + "scales_and_zeros", + torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, + self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + + class Int8DynActInt4WeightLinear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + + in_features: int + out_features: int + weight: torch.Tensor + + """ + This module implements a dynamic quantized linear layer with int4 weight. + Weights are per channel groupwise quantized. Parameters of importance + groupsize: the number of elements in each quantized group + precision: precision of input and output. e.g. torch.float32 means input + activation is float32 and output is float32. + scales_precision: precision of per group scale. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + # always pad if needed since it becomes a noop at runtime if not needed + # self.origin_in_features = in_features + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + # in_features = _calc_padded_size_linear_int4( + # in_features, groupsize + # ) + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + # TODO: align groupsize naming + self.groupsize = groupsize + # Precision of the activation which also indicates + # output precision of the dynamically quantized linear layer + # that his module represents. + self.precision = precision + + # currently storing unpacked int8 weights + self.register_buffer( + "weight", + torch.empty((out_features, in_features), dtype=torch.int8), + ) + self.register_buffer( + "scales", + torch.empty( + (out_features, in_features // groupsize), + dtype=scales_precision, + ), + ) + self.register_buffer( + "zeros", + torch.empty( + (out_features, in_features // groupsize), + dtype=scales_precision, + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(self.precision) + # padding is removed for perf + # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_8da4w( + input, + self.weight, + self.scales, + self.zeros, + self.out_features, + self.groupsize, + self.precision, + ) + + + def find_multiple(n: int, *args: Tuple[int]) -> int: + k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] + if n % k == 0: + return n + return n + k - (n % k) + + def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): + k_divisible_by_groupsize = k % groupsize == 0 + if inner_k_tiles is not None: + k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 + return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles + return k_divisible_by_groupsize + + def _calc_padded_size_linear_int4(k, groupsize=1): + return find_multiple(k, groupsize) + + def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + def pack_scales_and_zeros(scales, zeros, precision=torch.float32): + assert scales.shape == zeros.shape + assert scales.dtype == precision + assert zeros.dtype == precision + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda + )) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda) + + def replace_linear_8da4w( + module, + groupsize, + padding_allowed, + precision, + scales_precision, + ): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed: + setattr( + module, + name, + Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + precision=precision, + scales_precision=scales_precision, + ), + ) + else: + replace_linear_8da4w( + child, + groupsize, + padding_allowed, + precision, + scales_precision, + ) + + def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + + class Int8DynActInt4WeightQuantizer(Quantizer): + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + inner_k_tiles: Optional[int] = None, + _is_gpt_fast: bool = False, + _use_cuda: bool = True, + ) -> None: + super().__init__() + if _is_gpt_fast: + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + else: + assert inner_k_tiles is None + self._is_gpt_fast = _is_gpt_fast + self._use_cuda = _use_cuda + self.inner_k_tiles = inner_k_tiles + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.scales_precision: torch.dtype = scales_precision + + @torch.no_grad() + def _create_quantized_state_dict( + self, model: torch.nn.Module + ) -> Dict[str, torch.Tensor]: + cur_state_dict = model.state_dict() + for fqn, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + # assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + assert ( + in_features % self.groupsize == 0 + ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0" + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding_allowed: + from model import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + ( + weight_int8, + scales, + zeros, + ) = group_quantize_tensor_symmetric( + weight.to(self.precision), + 4, # n_bit + self.groupsize, + self.scales_precision, + ) + if self._is_gpt_fast: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int8.to(torch.int32), self.inner_k_tiles) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") + else: + cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu") + cur_state_dict[f"{fqn}.scales"] = scales.to("cpu") + cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu") + # TODO: support bias? + + return cur_state_dict + + def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: + if self._is_gpt_fast: + # TODO: temporary path for gpt-fast, will remove later + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + self._use_cuda, + ) + else: + replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.precision, + ) + return model + + def quantize( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict(model) + model = self._convert_for_runtime(model) + # TODO: make it strict + model.load_state_dict(state_dict, strict=False) + return model + + + # TODO: consolidate with other quantizers + class Int4WeightQuantizer(Quantizer): + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + inner_k_tiles: Optional[int] = None, + _use_cuda: bool = True, + ) -> None: + super().__init__( + groupsize, + padding_allowed, + precision, + torch.float32, # scales_precision + inner_k_tiles, + True, # _is_gpt_fast + _use_cuda, + ) + + + class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer): + def __init__( + self, + blocksize, + percdamp, + groupsize, + inner_k_tiles=8, + padding_allowed=True, + precision=torch.float32, + _is_gpt_fast=False, + _use_cuda=True, + ): + self._is_gpt_fast = _is_gpt_fast + self._use_cuda = _use_cuda + self.blocksize = blocksize + self.percdamp = percdamp + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding_allowed = padding_allowed + self.precision = precision + + self.act_fake_quant_func = per_token_dynamic_quant + n_bit = 4 + self.get_qparams_func = lambda w: get_group_qparams_symmetric( + w, n_bit, groupsize, self.precision + ) + quant_min = -(2 ** (n_bit - 1)) + quant_max = 2 ** (n_bit - 1) - 1 + + self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group( + w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize + ) + + self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group( + q, + qparams[0], + qparams[1], + quant_min, + quant_max, + torch.int8, + groupsize, + self.precision, + ) + + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] + # skip unless padding_allowed=True or its correctly sized + + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed + ) + + # we need to do the padding here, both for q and the qparams if necessary + + # TODO: this is the gpt-fast version, merge with the main version later + def make_names_and_values_dict_func_gpt_fast(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + q = q.to(torch.int32) + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales = qparams[0].to(torch.bfloat16) + zeros = qparams[1].to(torch.bfloat16) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = _calc_padded_size_linear_int4(k, groupsize) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = F.pad(q, pad=(0, delta_k)) + scales = qparams[0].to(self.precision) + zeros = qparams[1].to(self.precision) + return {"weight": final_q, "scales": scales, "zeros": zeros} + + self.make_names_and_values_dict_func = make_names_and_values_dict_func_gpt_fast if self._is_gpt_fast else make_names_and_values_dict_func + super().__init__() + + def _convert_for_runtime(self, model): + if self._is_gpt_fast: + # TODO: temporary path for gpt-fast, will remove later + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + self._use_cuda, + ) + else: + replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.precision, + ) + return model + + def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict( + model, + inputs, + self.blocksize, + self.percdamp, + self.groupsize, + ) + model = self._convert_for_runtime(model) + model.load_state_dict(state_dict, strict=False) + return model + + + # TODO: consolidate with other quantizers + class Int4WeightGPTQQuantizer(Int8DynActInt4WeightGPTQQuantizer): + + def __init__( + self, + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + inner_k_tiles=8, + padding_allowed=True, + precision=torch.float32, + _use_cuda=True, + ): + super().__init__( + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + inner_k_tiles=8, + padding_allowed=True, + precision=torch.float32, + _is_gpt_fast=True, + _use_cuda=_use_cuda, + ) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 80599cb71c..4bfb279769 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -10,6 +10,7 @@ from .quant_primitives import * # noqa: F403 from .utils import * # noqa: F403 from .weight_only import * # noqa: F403 +from .unified import * __all__ = [ "DynamicallyPerAxisQuantizedLinear", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index fb83f90b22..f797b058b2 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -15,9 +15,6 @@ and mixed GEMM kernels """ -import logging -from typing import Any, Dict, Tuple - import torch import torch.nn as nn import torch.nn.functional as F @@ -32,12 +29,7 @@ QuantizedLinearWeightBase, ) from .weight_only import WeightOnlyInt8QuantLinear - -_AFTER_TORCH_2_3_ONLY = [ - "Int8DynActInt4WeightQuantizer", - "Int8DynActInt4WeightGPTQQuantizer", -] - +from .unified import Quantizer, TwoStepQuantizer __all__ = [ "apply_weight_only_int8_quant", @@ -48,35 +40,17 @@ "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", -] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) - - -############################# Unified Quantization APIs ############################## -# API 1, single quantize call to create a quantized model with quantized state_dict -class Quantizer: - def quantize( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - - pass - - -# API 2, flow that needs calibration or training -class TwoStepQuantizer: - def prepare( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - - pass - - def convert( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - - pass - +] -############################# Unified Quantization APIs ############################## +if TORCH_VERSION_AFTER_2_3: + from .GPTQ import ( + Int8DynActInt4WeightQuantizer, + Int8DynActInt4WeightGPTQQuantizer, + ) + __all__ += [ + "Int8DynActInt4WeightQuantizer", + "Int8DynActInt4WeightGPTQQuantizer", + ] def _replace_with_custom_fn_if_matches_filter( @@ -223,599 +197,5 @@ def replace_conv2d_1x1(conv): model, replace_conv2d_1x1, filter_fn=filter_fn ) - if TORCH_VERSION_AFTER_2_3: - from .quant_primitives import ( - get_group_qparams_symmetric, - group_quantize_tensor_symmetric, - per_token_dynamic_quant, - ) - - from .GPTQ import lm_eval_available - - if lm_eval_available: - from .GPTQ import ( - evaluate, - GenericGPTQRunner, - get_task_dict, - InputRecorder, - lm_eval, - MultiInput, - ) - else: - logging.info("lm_eval not available, skip defining GPTQQuantizer") - - - class GPTQQuantizer(Quantizer): - """ - This class implements a GPTQ Quantizer that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. - Unlike the base Quantizer class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement - __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. - - The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and - create_quantized_state_dict. Here is a description of each function. - - get_qparams_func: - A function that calculates the quantization qparams for an input tensor. - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - qparams: it can have any format but will need to be handled by the other defined functions below. - - quantize_func: - A function that applies quantization to an input tensor. It should be noted - that this function needs to be able to handle quantizing the entire weight tensor, a single group, - or a single column. - Args: - weight: A 2d weight tensor with non-integer dtype. - qparams: the output from get_qparams_func - Returns: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - - - dequantize_func: - A function that dequantizes an input quantized weight tensor. It should be noted - that this function needs to be able to handle dequantizing the entire weight tensor, a single group, - or a single column. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - weight: A 2d weight tensor with non-integer dtype. - - dyn_quant_func (optional): - A function that dynamically quantizes inputs - Args: - input: input Tensor in f32/bf16/f16 - Returns: - output: dynamically quantized and dequantized Tensor (with the same dtype as input) - - combine_qparams_list_func: - A function that combines several qparams into one qparam. - Args: - qparams_list: a list of qparams objects, each obtained by calling get_qparams_func - on a single group from a weight tensor - Returns: - qparams: an object of the same format as the qparams above. - - skip_layer_func: - A function that determines which linear layers should be skipped during GPTQ - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - skip: boolean indicating whether layer should be skipped - - make_names_and_values_dict_func: - A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they - should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the - corresponding quantized weights and qparams. - """ - - def __init__(self): - - assert self.get_qparams_func is not None - - assert self.quantize_func is not None - - assert self.dequantize_func is not None - - assert self.combine_qparams_list_func is not None - - # `make_names_and_values_dict_func`. - assert self.make_names_and_values_dict_func is not None - - @staticmethod - def get_inputs( - model, - tokenizer, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - ) -> "MultiInput": - input_recorder = InputRecorder( - model, - tokenizer, - calibration_seq_length, - pad_calibration_inputs, - ) - - try: - - lm_eval.tasks.initialize_tasks() - except: - pass - - task_dict = get_task_dict(calibration_tasks) - print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) - - evaluate( - input_recorder, - task_dict, - limit=calibration_limit, - ) - inputs = input_recorder.get_recorded_inputs() - assert inputs is not None, ( - f"No inputs were collected, use a task other than {calibration_tasks}, " - + "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " - + f"{calibration_seq_length})" - ) - print(f"Obtained {len(inputs[0].values)} calibration samples") - return inputs - - @torch.no_grad() - def _create_quantized_state_dict( - self, - model, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - # `typing.Dict[, ]` to avoid runtime subscripting errors. - ) -> Dict: - inputs = GPTQQuantizer.get_inputs( - model, - tokenizer, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - ) - print("Tracing model for GPTQ") - GPTQ_runner = GenericGPTQRunner( - model, - inputs, - blocksize, - percdamp, - groupsize, - ).configure_quantization_mode( - self.get_qparams_func, # pyre-ignore[16] - self.quantize_func, # pyre-ignore[16] - self.dequantize_func, # pyre-ignore[16] - self.combine_qparams_list_func, # pyre-ignore[16] - self.make_names_and_values_dict_func, # pyre-ignore[16] - self.skip_layer_func, # pyre-ignore[16] - self.dyn_quant_func if hasattr(self, "dyn_quant_func") else None, # pyre-ignore[16] - ) - print("Applying GPTQ to weights") - GPTQ_runner.run() - return GPTQ_runner.get_quantized_state_dict() - - def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": - raise NotImplementedError("_convert_for_runtime not implemented") - - @torch.no_grad() - def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module: - state_dict = self._create_quantized_state_dict( - model, - self.tokenizer, - self.blocksize, - self.percdamp, - self.groupsize, - self.calibration_tasks, - self.calibration_limit, - self.calibration_seq_length, - self.pad_calibration_inputs, - ) - model = self._convert_for_runtime(model) - model.load_state_dict(state_dict, strict=False) - return model - - - def linear_forward_8da4w( - x, - weight_int8, - scales, - zeros, - out_features, - group_size, - precision, - ): - x = per_token_dynamic_quant(x) - # TODO: verify and remove following reshape code - # origin_x_size = x.size() - # x = x.reshape(-1, origin_x_size[-1]) - - # TODO: better API - # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed) - n_bit = 4 - quant_min = -(2 ** (n_bit - 1)) - quant_max = 2 ** (n_bit - 1) - 1 - w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - weight_int8, - scales, - zeros, - quant_min, - quant_max, - torch.int8, - group_size, - precision, - ) - - # x = x.to(torch.float16) - # w_dq = w_dq.to(torch.float16) - c = torch.nn.functional.linear(x, w_dq) - - # new_shape = origin_x_size[:-1] + (out_features,) - # c = c.reshape(new_shape) - - return c - - - class Int8DynActInt4WeightLinear(torch.nn.Module): - __constants__ = ["in_features", "out_features"] - - in_features: int - out_features: int - weight: torch.Tensor - - """ - This module implements a dynamic quantized linear layer with int4 weight. - Weights are per channel groupwise quantized. Parameters of importance - group_size: the number of elements in each quantized group - precision: precision of input and output. e.g. torch.float32 means input - activation is float32 and output is float32. - scales_precision: precision of per group scale. - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias=True, - device=None, - dtype=None, - group_size: int = 256, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - # always pad if needed since it becomes a noop at runtime if not needed - # self.origin_in_features = in_features - assert ( - in_features % group_size == 0 - ), f"require in_features:{in_features} % group_size:{group_size} == 0" - # in_features = _calc_padded_size_linear_int4( - # in_features, group_size - # ) - self.in_features = in_features - self.out_features = out_features - assert not bias, "require bias=False" - # TODO: align groupsize naming - self.group_size = group_size - # Precision of the activation which also indicates - # output precision of the dynamically quantized linear layer - # that his module represents. - self.precision = precision - - # currently storing unpacked int8 weights - self.register_buffer( - "weight", - torch.empty((out_features, in_features), dtype=torch.int8), - ) - self.register_buffer( - "scales", - torch.empty( - (out_features, in_features // group_size), - dtype=scales_precision, - ), - ) - self.register_buffer( - "zeros", - torch.empty( - (out_features, in_features // group_size), - dtype=scales_precision, - ), - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(self.precision) - # padding is removed for perf - # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - return linear_forward_8da4w( - input, - self.weight, - self.scales, - self.zeros, - self.out_features, - self.group_size, - self.precision, - ) - - - from functools import reduce - from math import gcd - - - def find_multiple(n: int, *args: Tuple[int]) -> int: - # TODO: this change is reverted right now in gpt-fast - k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] - if n % k == 0: - return n - return n + k - (n % k) - - - def _check_linear_int4_k(k, group_size=1): - return k % group_size == 0 - - - def _calc_padded_size_linear_int4(k, groupsize=1): - return find_multiple(k, groupsize) - - - def pack_scales_and_zeros(scales, zeros, precision=torch.float32): - assert scales.shape == zeros.shape - assert scales.dtype == precision - assert zeros.dtype == precision - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - - def unpack_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - assert scales_and_zeros.dtype == torch.float - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) - - - def replace_linear_8da4w( - module, - group_size, - padding_allowed, - precision, - scales_precision, - ): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - if _check_linear_int4_k(child.in_features, group_size) or padding_allowed: - setattr( - module, - name, - Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - group_size=group_size, - precision=precision, - scales_precision=scales_precision, - ), - ) - else: - replace_linear_8da4w( - child, - group_size, - padding_allowed, - precision, - scales_precision, - ) - - - class Int8DynActInt4WeightQuantizer(Quantizer): - def __init__( - self, - group_size: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - self.group_size: int = group_size - self.padding_allowed: bool = padding_allowed - self.precision: torch.dtype = precision - self.scales_precision: torch.dtype = scales_precision - # assert group_size in [32, 64, 128, 256] - - @torch.no_grad() - def _create_quantized_state_dict( - self, model: torch.nn.Module - ) -> Dict[str, torch.Tensor]: - cur_state_dict = model.state_dict() - for fqn, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear): - assert not mod.bias - out_features = mod.out_features - in_features = mod.in_features - # assert out_features % 8 == 0, "require out_features % 8 == 0" - print(f"linear: {fqn}, in={in_features}, out={out_features}") - - assert ( - in_features % self.group_size == 0 - ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0" - - weight = mod.weight.data - """ - if not _check_linear_int4_k( - in_features, self.group_size - ): - if self.padding_allowed: - print( - f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" - ) - padded_in_features = _calc_padded_size_linear_int4( - in_features, self.group_size - ) - weight = F.pad( - weight, pad=(0, padded_in_features - in_features) - ) - else: - raise RuntimeError( - f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " - + "and that group_size" - ) - """ - ( - weight_int8, - scales, - zeros, - ) = group_quantize_tensor_symmetric( - weight.to(self.precision), - 4, # n_bit - self.group_size, - self.scales_precision, - ) - cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu") - cur_state_dict[f"{fqn}.scales"] = scales.to("cpu") - cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu") - # TODO: support bias? - - return cur_state_dict - - def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: - replace_linear_8da4w( - model, - self.group_size, - self.padding_allowed, - self.precision, - self.scales_precision, - ) - return model - - def quantize( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - state_dict = self._create_quantized_state_dict(model) - model = self._convert_for_runtime(model) - # TODO: make it strict - model.load_state_dict(state_dict, strict=False) - return model - - - class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer): - - def __init__( - self, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - inner_k_tiles=8, - padding_allowed=True, - precision=torch.float32, - ): - - self.tokenizer = tokenizer - - self.blocksize = blocksize - - self.percdamp = percdamp - - self.groupsize = groupsize - - self.calibration_tasks = calibration_tasks - - self.calibration_limit = calibration_limit - - self.calibration_seq_length = calibration_seq_length - - self.pad_calibration_inputs = pad_calibration_inputs - - self.inner_k_tiles = inner_k_tiles - - self.padding_allowed = padding_allowed - - self.precision = precision - - self.dyn_quant_func = per_token_dynamic_quant - n_bit = 4 - - self.get_qparams_func = lambda w: get_group_qparams_symmetric( - w, n_bit, groupsize, self.precision - ) - quant_min = -(2 ** (n_bit - 1)) - quant_max = 2 ** (n_bit - 1) - 1 - - self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group( - w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize - ) - - self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group( - q, - qparams[0], - qparams[1], - quant_min, - quant_max, - torch.int8, - groupsize, - self.precision, - ) - - self.combine_qparams_list_func = lambda qparams_list: [ - torch.cat(x, dim=1) for x in zip(*qparams_list) - ] - # skip unless padding_allowed=True or its correctly sized - - self.skip_layer_func = lambda linear_weight: not ( - _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed - ) - - # we need to do the padding here, both for q and the qparams if necessary - - def make_names_and_values_dict_func(q, qparams): - k = q.shape[1] - new_k = _calc_padded_size_linear_int4(k, groupsize) - # how much we need to pad the weight - delta_k = new_k - q.shape[1] - final_q = F.pad(q, pad=(0, delta_k)) - scales = qparams[0].to(self.precision) - zeros = qparams[1].to(self.precision) - # scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision) - # how many new groups we need for padded weight - # delta_groups = new_k // groupsize - scales_and_zeros.shape[0] - # TODO: split scales and zero_points - # final_s_and_z = F.pad( - # scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1 - # ) - return {"weight": final_q, "scales": scales, "zeros": zeros} - - self.make_names_and_values_dict_func = make_names_and_values_dict_func - super().__init__() - - def _convert_for_runtime(self, model): - replace_linear_8da4w( - model, - self.groupsize, - self.padding_allowed, - self.precision, - self.precision, - ) - return model + from .GPTQ import Int8DynActInt4WeightQuantizer, Int8DynActInt4WeightGPTQQuantizer diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 3e05dd3c42..e629ae32bc 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -11,7 +11,6 @@ from torch.library import impl from torchao.kernel.intmm import int_scaled_matmul -from .utils import TORCH_VERSION_AFTER_2_4 from torchao.kernel.intmm import safe_int_mm from .utils import TORCH_VERSION_AFTER_2_3 diff --git a/torchao/quantization/unified.py b/torchao/quantization/unified.py new file mode 100644 index 0000000000..16112ac0f0 --- /dev/null +++ b/torchao/quantization/unified.py @@ -0,0 +1,29 @@ +import torch +from typing import Any + +############################# Unified Quantization APIs ############################## +# API 1, single quantize call to create a quantized model with quantized state_dict +class Quantizer: + def quantize( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + + pass + + +# API 2, flow that needs calibration or training +class TwoStepQuantizer: + def prepare( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + + pass + + def convert( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + + pass + + +############################# Unified Quantization APIs ############################## From 251fddb13a05839a2f1fec3d7677e7c329b0efe5 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Sat, 30 Mar 2024 22:17:55 -0700 Subject: [PATCH 05/21] Github Action Strategy Refactor (#97) * Increase nightly regression test coverage Have a config that that tries out different runners and pytorch versions * Update regression_test.yml --------- Co-authored-by: cpuhrsch --- .github/workflows/regression_test.yml | 120 +++++--------------------- 1 file changed, 22 insertions(+), 98 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 392ee1947c..fd3b5b4943 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -9,114 +9,38 @@ on: - main jobs: - test-cuda-2-2-2: - runs-on: 4-core-ubuntu-gpu-t4 + test: + strategy: + matrix: + include: + - name: CUDA 2.2.2 + runs-on: 4-core-ubuntu-gpu-t4 + torch-spec: 'torch==2.2.2' + - name: CUDA 2.3 RC + runs-on: 4-core-ubuntu-gpu-t4 + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + - name: CUDA Nightly + runs-on: 4-core-ubuntu-gpu-t4 + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + - name: CPU + runs-on: 32-core-ubuntu + torch-spec: 'torch --index-url https://download.pytorch.org/whl/cpu' + - name: Nightly CPU + runs-on: 32-core-ubuntu + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' + runs-on: ${{ matrix.runs-on }} steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: '3.9' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install torch==2.2.2 - pip install -r requirements.txt - pip install -r dev-requirements.txt - - - name: Install package - run: | - pip install . - - - name: Run tests - run: | - pytest test --verbose -s -x - - test-cuda-2-3-rc: - runs-on: 4-core-ubuntu-gpu-t4 - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 - pip install -r requirements.txt - pip install -r dev-requirements.txt - - test-cuda-nightly: - runs-on: 4-core-ubuntu-gpu-t4 - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 - pip install -r requirements.txt - pip install -r dev-requirements.txt - - - - name: Install package - run: | - pip install . - - - name: Run tests - run: | - pytest test --verbose -s -x - - test-cpu: - runs-on: 32-core-ubuntu - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -r requirements.txt - pip install -r dev-requirements.txt - - - - name: Install package - run: | - pip install . - - - name: Run tests - run: | - pytest test --verbose -s -x - - test-nightly-cpu: - runs-on: 32-core-ubuntu - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu + pip install ${{ matrix.torch-spec }} pip install -r requirements.txt pip install -r dev-requirements.txt From 8c62eb04bee7b2b0b2c5629a68ddcc7a80424f39 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Mon, 1 Apr 2024 11:26:34 -0700 Subject: [PATCH 06/21] Reenable nf4 compile smoke test (#101) --- test/dtypes/test_nf4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 5cd967f9d2..c71fcd25b6 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -192,7 +192,6 @@ def test_smoketest_linear(self): out1 = torch.nn.functional.linear(inp, a) out2 = torch.nn.functional.linear(inp, a_nf4) - @unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_smoketest_linear_compile(self): for dtype in [torch.bfloat16, torch.float16]: From a8704f8ff7cbd01390b9bcaa6640f2eae311294d Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 1 Apr 2024 15:45:51 -0400 Subject: [PATCH 07/21] [test] `get_group_qparams_symmetric` matches observer (#94) --- test/quantization/test_quant_primitives.py | 50 ++++++++++++++++++++++ torchao/quantization/quant_primitives.py | 1 + 2 files changed, 51 insertions(+) create mode 100644 test/quantization/test_quant_primitives.py diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py new file mode 100644 index 0000000000..fef3f83dcc --- /dev/null +++ b/test/quantization/test_quant_primitives.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# mypy: ignore-errors +# This test takes a long time to run +import unittest +import torch +from torchao.quantization.quant_primitives import get_group_qparams_symmetric +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +class TestQuantPrimitives(unittest.TestCase): + SEED = 123 + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") + def test_get_group_qparams_symmetric(self): + """ + Test that `get_group_qparams_symmetric` produces the exact same scales as + `PerChannelMinMaxObserver._calculate_qparams`. + """ + n_bit = 4 + qmin = -(2 ** (n_bit - 1)) + qmax = 2 ** (n_bit - 1) - 1 + eps = torch.finfo(torch.float32).eps + groupsize = 256 + torch.manual_seed(self.SEED) + weight = torch.randn(100, 256).to(torch.float16) + + # calculate observer scales + obs = torch.ao.quantization.PerChannelMinMaxObserver( + ch_axis=0, + qscheme=torch.per_channel_symmetric, + quant_min=qmin, + quant_max=qmax, + # This is needed to ensure `min_val` and `max_val` are fp16, + # otherwise they default to fp32 and the qparams will be slightly off + factory_kwargs={"dtype": torch.float16} + ) + obs(weight) + (scale_obs, _) = obs.calculate_qparams() + scale_obs = scale_obs.reshape(weight.shape[0], -1) + + # assert that scales are identical + (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize) + torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0) + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e629ae32bc..5baa289729 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -470,6 +470,7 @@ def groupwise_affine_dequantize_tensor( ) +# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32): # needed for GPTQ with padding if groupsize > w.shape[-1]: From 7b5a097f2d487076d374faab2d7bb4fa0336ec29 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 1 Apr 2024 13:10:44 -0700 Subject: [PATCH 08/21] Fix import versions for GPTQ (#105) --- test/quantization/model.py | 1 + torchao/quantization/GPTQ.py | 7 ++++--- torchao/quantization/quant_api.py | 7 ++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/quantization/model.py b/test/quantization/model.py index b9705313e6..940f109d04 100644 --- a/test/quantization/model.py +++ b/test/quantization/model.py @@ -12,6 +12,7 @@ from torch.nn import functional as F def prepare_inputs_for_model(inps): + inps = inps.squeeze(0) # setup inputs in correct format max_new_tokens = 1 T = inps.size(0) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 304a84ac56..062e8ae344 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -20,7 +20,7 @@ # from model import Transformer # pyre-ignore[21] from torch.utils._pytree import tree_flatten, tree_unflatten -from .utils import TORCH_VERSION_AFTER_2_4 +from .utils import TORCH_VERSION_AFTER_2_3 from typing import Any, Dict, Tuple, Optional from .unified import Quantizer from functools import reduce @@ -89,7 +89,7 @@ def __init__( # for model self.input_prep_func = ( input_prep_func if input_prep_func is not None - else lambda x: x + else lambda x: (x,) ) self.pad_calibration_inputs = pad_calibration_inputs @@ -180,6 +180,7 @@ def _model_call(self, inps): else: inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) + inps = inps.unsqueeze(0) model_in = self.input_prep_func(inps) self.add_input(model_in) @@ -546,7 +547,7 @@ def faster_quant(self, H, W): return Q, DQ.to(orig_dtype), all_qparams -if TORCH_VERSION_AFTER_2_4: +if TORCH_VERSION_AFTER_2_3: from .quant_primitives import ( get_group_qparams_symmetric, group_quantize_tensor_symmetric, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f797b058b2..a17daf8697 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -46,10 +46,14 @@ from .GPTQ import ( Int8DynActInt4WeightQuantizer, Int8DynActInt4WeightGPTQQuantizer, + Int4WeightQuantizer, + Int4WeightGPTQQuantizer, ) __all__ += [ "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer", + "Int4WeightQuantizer", + "Int4WeightGPTQQuantizer", ] @@ -196,6 +200,3 @@ def replace_conv2d_1x1(conv): _replace_with_custom_fn_if_matches_filter( model, replace_conv2d_1x1, filter_fn=filter_fn ) - -if TORCH_VERSION_AFTER_2_3: - from .GPTQ import Int8DynActInt4WeightQuantizer, Int8DynActInt4WeightGPTQQuantizer From 83176cdd135e9c9ee734952477f04e7a38f90dd9 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Mon, 1 Apr 2024 13:57:25 -0700 Subject: [PATCH 09/21] Add 2.3 RC for CUDA and CPU. (#106) --- .github/workflows/regression_test.yml | 9 ++++++--- test/quantization/test_quant_api.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index fd3b5b4943..560df57c3d 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -18,13 +18,16 @@ jobs: torch-spec: 'torch==2.2.2' - name: CUDA 2.3 RC runs-on: 4-core-ubuntu-gpu-t4 - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/test/cu121' - name: CUDA Nightly runs-on: 4-core-ubuntu-gpu-t4 torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' - - name: CPU + - name: CPU 2.2.2 + runs-on: 32-core-ubuntu + torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu' + - name: CPU 2.3 RC runs-on: 32-core-ubuntu - torch-spec: 'torch --index-url https://download.pytorch.org/whl/cpu' + torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/test/cpu' - name: Nightly CPU runs-on: 32-core-ubuntu torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 156fdbd78a..b32a07be01 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -26,6 +26,7 @@ ) from torchao.quantization.utils import ( TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, ) from pathlib import Path from sentencepiece import SentencePieceProcessor @@ -136,7 +137,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_8da4w_quantizer(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear From a7963ae5bebaca83ad261fb5d0ac41b08ec648e9 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Mon, 1 Apr 2024 15:43:45 -0700 Subject: [PATCH 10/21] Reenable flaky tests (#110) --- test/integration/test_integration.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 6873f7fbac..801ca10bc2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -902,7 +902,6 @@ def test_int8_dynamic_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skip("flaky test, will fix in another PR") def test_int8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype @@ -976,7 +975,6 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skip("flaky test, will fix in another PR") def test_int8_weight_only_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype @@ -1157,7 +1155,6 @@ def test_save_load_dqtensors(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() - @unittest.skip("flaky test, will fix in another PR") def test_save_load_int8woqtensors(self, device, dtype): self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype) From 046dc985de6d5eac05c6575cc71505687e3aadf1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 1 Apr 2024 18:12:04 -0700 Subject: [PATCH 11/21] Some fixes in the recently added quantizer (#111) --- test/quantization/model.py | 4 ++++ torchao/quantization/GPTQ.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/quantization/model.py b/test/quantization/model.py index 940f109d04..19a826781b 100644 --- a/test/quantization/model.py +++ b/test/quantization/model.py @@ -12,6 +12,10 @@ from torch.nn import functional as F def prepare_inputs_for_model(inps): + # this is because input from lm-eval is 2d + if input.dim() != 2: + raise ValueError(f"Expected input to be of dim 2, but got {input.dim()}") + inps = inps.squeeze(0) # setup inputs in correct format max_new_tokens = 1 diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 062e8ae344..c22ffa79d4 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -1225,9 +1225,9 @@ def __init__( calibration_limit, calibration_seq_length, pad_calibration_inputs, - inner_k_tiles=8, - padding_allowed=True, - precision=torch.float32, - _is_gpt_fast=True, + inner_k_tiles=inner_k_tiles, + padding_allowed=padding_allowed, + precision=precision, + _is_gpt_fast=_is_gpt_fast, _use_cuda=_use_cuda, ) From b0a333c73098e52be830942245675b96230e0e1c Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Wed, 3 Apr 2024 13:18:44 -0400 Subject: [PATCH 12/21] add int4 gptq and eval (#116) * add int4 gptq and eval Summary: adding int4 gptq and eval support. Also fixed a few bugs relating to quantizing the activation both during gptq calculation and when calculating the output. Test Plan: python test/quantization/test_quant_api.py Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: d29b6d73c90dec5171e12938afee25e5f42e042d Pull Request resolved: https://github.com/pytorch-labs/ao/pull/115 * add int4 gptq and eval Summary: adding int4 gptq and eval support. Also fixed a few bugs relating to quantizing the activation both during gptq calculation and when calculating the output. Test Plan: python test/quantization/test_quant_api.py Reviewers: Subscribers: Tasks: Tags: * remove debug from GPTQ Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/quantization/model.py | 13 +- test/quantization/test_quant_api.py | 149 ++++- torchao/quantization/GPTQ.py | 661 ++++++++++++----------- torchao/quantization/quant_api.py | 11 +- torchao/quantization/quant_primitives.py | 19 +- torchao/quantization/utils.py | 8 +- 6 files changed, 522 insertions(+), 339 deletions(-) diff --git a/test/quantization/model.py b/test/quantization/model.py index 19a826781b..e851901c41 100644 --- a/test/quantization/model.py +++ b/test/quantization/model.py @@ -10,15 +10,15 @@ import torch.nn as nn from torch import Tensor from torch.nn import functional as F +from torchao.quantization.utils import find_multiple -def prepare_inputs_for_model(inps): +def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d - if input.dim() != 2: - raise ValueError(f"Expected input to be of dim 2, but got {input.dim()}") + if inps.dim() != 2: + raise ValueError(f"Expected input to be of dim 2, but got {inps.dim()}") inps = inps.squeeze(0) # setup inputs in correct format - max_new_tokens = 1 T = inps.size(0) T_new = T + max_new_tokens seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device) @@ -27,11 +27,6 @@ def prepare_inputs_for_model(inps): x = seq.index_select(0, input_pos).view(1, -1) return (x, input_pos) -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - @dataclass class ModelArgs: block_size: int = 2048 diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b32a07be01..700c9c9b98 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -151,8 +151,8 @@ def test_8da4w_quantizer(self): m(*example_inputs) @unittest.skip("skipping until we get checkpoints for gpt-fast") - def test_gptq_quantizer(self): - from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder + def test_8da4w_gptq_quantizer(self): + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -161,6 +161,7 @@ def test_gptq_quantizer(self): checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) model = model.to(dtype=precision, device=device) + model.eval() tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), tokenizer_path tokenizer = SentencePieceProcessor( # pyre-ignore[28] @@ -190,12 +191,60 @@ def test_gptq_quantizer(self): blocksize, percdamp, groupsize, + precision=precision, ) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs) - compiled = torch.compile(model, mode="max-autotune") - with torch.no_grad(): - compiled(inputs[0].values[0], inputs[1].values[0]) + result=TransformerEvalWrapper( + model, + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + + assert result['results']['wikitext']['word_perplexity,none'] < 7.88, ( + f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" + ) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + def test_8da4w_quantizer_eval(self): + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + from torchao.quantization.GPTQ import TransformerEvalWrapper + + precision = torch.bfloat16 + device = "cpu" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + + quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision) + q_model = quantizer.quantize(model) + result=TransformerEvalWrapper( + q_model, + tokenizer, + q_model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( + f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" + ) @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_gpt_fast(self): @@ -248,5 +297,95 @@ def test_gptq_quantizer_gpt_fast(self): with torch.no_grad(): compiled(inputs[0].values[0], inputs[1].values[0]) + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_gptq_quantizer_int4wo(self): + from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper + # should be similar to TorchCompileDynamicQuantizer + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device="cpu") + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + blocksize = 128 + percdamp = 0.01 + groupsize = 128 + calibration_tasks = ["wikitext"] + calibration_limit = 1 + calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model + pad_calibration_inputs = False + + inputs = InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + device="cpu", + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + + quantizer = Int4WeightOnlyGPTQQuantizer( + blocksize, + percdamp, + groupsize, + ) + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + + model = quantizer.quantize(model, inputs).cuda() + result = TransformerEvalWrapper( + model.cuda(), + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none'] < 7.77, ( + f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" + ) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_eval_wrapper(self): + from torchao.quantization.GPTQ import TransformerEvalWrapper + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + result=TransformerEvalWrapper( + model, + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none']<7.77, ( + f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" + ) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index c22ffa79d4..559ab54f7d 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -17,15 +17,18 @@ import torch.nn as nn import torch.nn.functional as F -# from model import Transformer # pyre-ignore[21] from torch.utils._pytree import tree_flatten, tree_unflatten -from .utils import TORCH_VERSION_AFTER_2_3 -from typing import Any, Dict, Tuple, Optional +from .utils import TORCH_VERSION_AFTER_2_3, find_multiple +from typing import Any, Dict, Optional from .unified import Quantizer -from functools import reduce -from math import gcd +from .quant_primitives import ( + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor_from_qparams, + groupwise_affine_dequantize_tensor_from_qparams, + pack_tinygemm_scales_and_zeros, +) aten = torch.ops.aten ## eval.py ## @@ -51,6 +54,21 @@ else: logging.info("lm_eval is not installed, GPTQ may not be usable") +add_ons = [] + +if lm_eval_available: + add_ons += ["InputRecorder", "TransformerEvalWrapper"] + +if TORCH_VERSION_AFTER_2_3: + add_ons += ["Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer"] + + +__all__ = [ + "MultiInput", + "WeightOnlyInt4Linear", + "Int4WeightOnlyGPTQQuantizer", +] + add_ons + if lm_eval_available: class InputRecorder(eval_wrapper): """ @@ -193,6 +211,61 @@ def _model_call(self, inps): def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented") + class TransformerEvalWrapper(InputRecorder): + """ + A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. + """ + def __init__( + self, + model, + tokenizer, + max_seq_length, + input_prep_func=None, + device="cuda" + ): + super().__init__(None, None) + self._model = model + self._tokenizer = tokenizer + self._device = torch.device(device) + self._max_seq_length = max_seq_length + + # need to take inps and convert to corrent input + # for model + self.input_prep_func = ( + input_prep_func if input_prep_func is not None + else lambda x: (x,) + ) + + def _model_call(self, inps): + # TODO: make batches work + input = self.input_prep_func(inps) + + max_seq_length = min(inps.size(1), self.max_length) + with torch.device(self._device): + self._model.setup_caches(self.batch_size, max_seq_length) + logits = self._model(*input) + return logits + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception('unimplemented') + + def run_eval(self, tasks, limit): + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(tasks) + print("Evaluating Model On: ", task_dict) + with torch.no_grad(): + result = evaluate( + self, + task_dict, + limit=limit, + ) + for task, res in result["results"].items(): + print(f"{task}: {res}") + return result class MultiInput: @@ -254,7 +327,7 @@ def __init__( self.groupsize = groupsize self.inputs = inputs self.gptq_done = False - self.debug = True + self.debug = False def configure_quantization_mode( self, @@ -290,7 +363,10 @@ def configure_quantization_mode( # note any final packing for storage should happen here # `act_fake_quant_func` - self.act_fake_quant_func = act_fake_quant_func # accepts [activation tensor], returns a fake-quantized activation tensor + if act_fake_quant_func is None: + self.act_fake_quant_func = lambda x: x + else: + self.act_fake_quant_func = act_fake_quant_func # accepts [activation tensor], returns a fake-quantized activation tensor return self def run(self): @@ -314,7 +390,7 @@ def get_quantized_state_dict(self): quantized_state_dict.pop(param_fqn) return quantized_state_dict - def call_function(self, target, args, kwargs, skip_quant=False): # noqa: C901 + def call_function(self, target, args, kwargs, already_quantized=False): # noqa: C901 def tensors_to_cuda(args): new_args = [] @@ -354,9 +430,11 @@ def tensors_to_cuda(args): quantize_linear = ( (target == aten.linear.default) # if its a linear and id(args[1]) in self.id_to_name # and if we know the layer name - and not skip_quant # and if we weren't told to skip quantization + # and we haven't already quantized this layer + and not already_quantized # and if the skip_layer_func doesn't say we should skip and not (self.skip_layer_func is not None and self.skip_layer_func(args[1])) + ) # then we will quantize this linear layer/weight if quantize_linear: # instantiate variables for GPTQ @@ -371,8 +449,7 @@ def tensors_to_cuda(args): quantize_linear ): # calculate H instead of output (will run the linear eventually with updated weight) x = cur_args[0].float() - if self.act_fake_quant_func is not None: - x = self.act_fake_quant_func(x) + x = self.act_fake_quant_func(x) shape = x.shape n = 1 if len(shape) == 2 else shape[0] H *= total_batches / (total_batches + n) @@ -382,10 +459,13 @@ def tensors_to_cuda(args): ).t().float() H += x.matmul(x.t()) else: + # weight has already been quantized but still need to apply + # activation quant for final calculation + if already_quantized: + cur_args = (self.act_fake_quant_func(cur_args[0]), *cur_args[1:]) + # get output if its not a linear out = super().call_function(target, cur_args, cur_kwargs) - # if isinstance(out, torch.Tensor) and (out.isnan().max() or out.sum()==0 or out.isinf().max()): - # breakpoint() if isinstance(out, torch.Tensor): outputs.append(out.cpu()) else: @@ -412,12 +492,12 @@ def tensors_to_cuda(args): # run linear with new weight to get corrected output new_out = self.call_function( - target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True + target, (args[0], DQ, *args[2:]), kwargs, already_quantized=True ) if self.debug: old_out = self.call_function( - target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True + target, (args[0][:2], args[1], *args[2:]), kwargs, already_quantized=True ) def SQNR(x, y): @@ -450,7 +530,7 @@ def SQNR(x, y): Q2 = self.quantize_func(W, qparams2) DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype) old_q_out = self.call_function( - target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True + target, (args[0][:2], DQ2, *args[2:]), kwargs, already_quantized=True ) print( @@ -547,134 +627,294 @@ def faster_quant(self, H, W): return Q, DQ.to(orig_dtype), all_qparams -if TORCH_VERSION_AFTER_2_3: - from .quant_primitives import ( - get_group_qparams_symmetric, - group_quantize_tensor_symmetric, - per_token_dynamic_quant, - ) +class GPTQQuantizer(Quantizer): + """ + This class implements a GPTQ Quantizer that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base Quantizer class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. + qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + act_fake_quant_func (optional): + A function that (dynamically) quantizes activation to input + Args: + input: input Tensor in f32/bf16/f16 + Returns: + output: dynamically quantized and dequantized Tensor (with the same dtype as input) + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. + """ - class GPTQQuantizer(Quantizer): - """ - This class implements a GPTQ Quantizer that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. - Unlike the base Quantizer class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement - __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + def __init__(self): - The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and - create_quantized_state_dict. Here is a description of each function. + assert self.get_qparams_func is not None - get_qparams_func: - A function that calculates the quantization qparams for an input tensor. - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - qparams: it can have any format but will need to be handled by the other defined functions below. + assert self.quantize_func is not None - quantize_func: - A function that applies quantization to an input tensor. It should be noted - that this function needs to be able to handle quantizing the entire weight tensor, a single group, - or a single column. - Args: - weight: A 2d weight tensor with non-integer dtype. - qparams: the output from get_qparams_func - Returns: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + assert self.dequantize_func is not None + assert self.combine_qparams_list_func is not None - dequantize_func: - A function that dequantizes an input quantized weight tensor. It should be noted - that this function needs to be able to handle dequantizing the entire weight tensor, a single group, - or a single column. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - weight: A 2d weight tensor with non-integer dtype. + # `make_names_and_values_dict_func`. + assert self.make_names_and_values_dict_func is not None - act_fake_quant_func (optional): - A function that (dynamically) quantizes activation to input - Args: - input: input Tensor in f32/bf16/f16 - Returns: - output: dynamically quantized and dequantized Tensor (with the same dtype as input) + @torch.no_grad() + def _create_quantized_state_dict( + self, + model, + inputs, + blocksize, + percdamp, + groupsize, + # `typing.Dict[, ]` to avoid runtime subscripting errors. + ) -> Dict: + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + model, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, # pyre-ignore[16] + self.quantize_func, # pyre-ignore[16] + self.dequantize_func, # pyre-ignore[16] + self.combine_qparams_list_func, # pyre-ignore[16] + self.make_names_and_values_dict_func, # pyre-ignore[16] + self.skip_layer_func, # pyre-ignore[16] + self.act_fake_quant_func if hasattr(self, "act_fake_quant_func") else None, # pyre-ignore[16] + ) + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() - combine_qparams_list_func: - A function that combines several qparams into one qparam. - Args: - qparams_list: a list of qparams objects, each obtained by calling get_qparams_func - on a single group from a weight tensor - Returns: - qparams: an object of the same format as the qparams above. + def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": + raise NotImplementedError("_convert_for_runtime not implemented") - skip_layer_func: - A function that determines which linear layers should be skipped during GPTQ - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - skip: boolean indicating whether layer should be skipped + @torch.no_grad() + def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: + pass - make_names_and_values_dict_func: - A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they - should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the - corresponding quantized weights and qparams. - """ - def __init__(self): +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c - assert self.get_qparams_func is not None +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor - assert self.quantize_func is not None + def __init__( + self, in_features: int, out_features: int, + bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, + ) -> None: + super().__init__() + self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + if self.padding: + from model import find_multiple + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles - assert self.dequantize_func is not None + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + if use_cuda: + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) + else: + self.register_buffer( + "weight", + torch.empty((out_features, in_features // 2), dtype=torch.uint8) + ) + self.register_buffer( + "scales_and_zeros", + torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) - assert self.combine_qparams_list_func is not None + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, + self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) - # `make_names_and_values_dict_func`. - assert self.make_names_and_values_dict_func is not None - @torch.no_grad() - def _create_quantized_state_dict( +def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): + k_divisible_by_groupsize = k % groupsize == 0 + if inner_k_tiles is not None: + k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 + return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles + return k_divisible_by_groupsize + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None): + + for name, child in module.named_children(): + if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda + )) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func) + +class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer): + def __init__( self, - model, - inputs, blocksize, percdamp, groupsize, - # `typing.Dict[, ]` to avoid runtime subscripting errors. - ) -> Dict: - print("Tracing model for GPTQ") - GPTQ_runner = GenericGPTQRunner( - model, - inputs, - blocksize, - percdamp, + inner_k_tiles=8, + padding_allowed=True, + ): + self.blocksize = blocksize + self.percdamp = percdamp + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding_allowed = padding_allowed + self.act_fake_quant_func = None + n_bit = 4 + self.get_qparams_func = lambda w: get_groupwise_affine_qparams( + w, n_bit, groupsize + ) + self.quantize_func = lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams( + w, qparams[0], qparams[1], n_bit, groupsize + ) + self.dequantize_func = lambda q, qparams: groupwise_affine_dequantize_tensor_from_qparams( + q, + qparams[0], + qparams[1], + n_bit, groupsize, - ).configure_quantization_mode( - self.get_qparams_func, # pyre-ignore[16] - self.quantize_func, # pyre-ignore[16] - self.dequantize_func, # pyre-ignore[16] - self.combine_qparams_list_func, # pyre-ignore[16] - self.make_names_and_values_dict_func, # pyre-ignore[16] - self.skip_layer_func, # pyre-ignore[16] - self.act_fake_quant_func if hasattr(self, "act_fake_quant_func") else None, # pyre-ignore[16] ) - print("Applying GPTQ to weights") - GPTQ_runner.run() - return GPTQ_runner.get_quantized_state_dict() + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] + # skip unless padding_allowed=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed + ) - def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": - raise NotImplementedError("_convert_for_runtime not implemented") + # we need to do the padding here, both for q and the qparams if necessary + + # TODO: this is the gpt-fast version, merge with the main version later + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + q = q.to(torch.int32) + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales = qparams[0].to(torch.bfloat16) + zeros = qparams[1].to(torch.bfloat16) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + def _convert_for_runtime(self, model): + # TODO: temporary path for gpt-fast, will remove later + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + skip_layer_func = self.skip_layer_func, + ) + return model - @torch.no_grad() def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: - pass + state_dict = self._create_quantized_state_dict( + model, + inputs, + self.blocksize, + self.percdamp, + self.groupsize, + ) + model = self._convert_for_runtime(model) + model.load_state_dict(state_dict, strict=False) + return model +if TORCH_VERSION_AFTER_2_3: + from .quant_primitives import ( + get_group_qparams_symmetric, + group_quantize_tensor_symmetric, + per_token_dynamic_quant, + ) + def linear_forward_8da4w( x, weight_int8, @@ -714,58 +954,6 @@ def linear_forward_8da4w( return c - - class WeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] - in_features: int - out_features: int - weight: torch.Tensor - - def __init__( - self, in_features: int, out_features: int, - bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, - ) -> None: - super().__init__() - self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles) - if self.padding: - from model import find_multiple - self.origin_in_features = in_features - in_features = find_multiple(in_features, 1024) - - self.in_features = in_features - self.out_features = out_features - assert not bias, "require bias=False" - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - - assert out_features % 8 == 0, "require out_features % 8 == 0" - assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" - if use_cuda: - self.register_buffer( - "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) - ) - else: - self.register_buffer( - "weight", - torch.empty((out_features, in_features // 2), dtype=torch.uint8) - ) - self.register_buffer( - "scales_and_zeros", - torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(torch.bfloat16) - if self.padding: - import torch.nn.functional as F - input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - return linear_forward_int4( - input, - self.weight, self.scales_and_zeros, self.out_features, self.groupsize - ) - - class Int8DynActInt4WeightLinear(torch.nn.Module): __constants__ = ["in_features", "out_features"] @@ -847,62 +1035,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) - def find_multiple(n: int, *args: Tuple[int]) -> int: - k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] - if n % k == 0: - return n - return n + k - (n % k) - - def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): - k_divisible_by_groupsize = k % groupsize == 0 - if inner_k_tiles is not None: - k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 - return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles - return k_divisible_by_groupsize - - def _calc_padded_size_linear_int4(k, groupsize=1): - return find_multiple(k, groupsize) - - def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): - origin_x_size = x.size() - x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) - new_shape = origin_x_size[:-1] + (out_features,) - c = c.reshape(new_shape) - return c - - def pack_scales_and_zeros(scales, zeros, precision=torch.float32): - assert scales.shape == zeros.shape - assert scales.dtype == precision - assert zeros.dtype == precision - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - def unpack_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - assert scales_and_zeros.dtype == torch.float - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) - - def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda - )) - else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda) - def replace_linear_8da4w( module, groupsize, @@ -934,23 +1066,6 @@ def replace_linear_8da4w( scales_precision, ) - def pack_scales_and_zeros(scales, zeros): - assert scales.shape == zeros.shape - assert scales.dtype == torch.bfloat16 - assert zeros.dtype == torch.bfloat16 - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - class Int8DynActInt4WeightQuantizer(Quantizer): def __init__( self, @@ -998,7 +1113,7 @@ def _create_quantized_state_dict( in_features, self.groupsize, self.inner_k_tiles ): if self.padding_allowed: - from model import find_multiple + from .utils import find_multiple import torch.nn.functional as F print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") padded_in_features = find_multiple(in_features, 1024) @@ -1019,7 +1134,7 @@ def _create_quantized_state_dict( ) if self._is_gpt_fast: weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int8.to(torch.int32), self.inner_k_tiles) - scales_and_zeros = pack_scales_and_zeros(scales, zeros) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") else: @@ -1060,27 +1175,6 @@ def quantize( return model - # TODO: consolidate with other quantizers - class Int4WeightQuantizer(Quantizer): - def __init__( - self, - groupsize: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - inner_k_tiles: Optional[int] = None, - _use_cuda: bool = True, - ) -> None: - super().__init__( - groupsize, - padding_allowed, - precision, - torch.float32, # scales_precision - inner_k_tiles, - True, # _is_gpt_fast - _use_cuda, - ) - - class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer): def __init__( self, @@ -1146,7 +1240,7 @@ def make_names_and_values_dict_func_gpt_fast(q, qparams): final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) scales = qparams[0].to(torch.bfloat16) zeros = qparams[1].to(torch.bfloat16) - scales_and_zeros = pack_scales_and_zeros(scales, zeros) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) # how many new groups we need for padded weight delta_groups = new_k // groupsize - scales_and_zeros.shape[0] final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) @@ -1154,7 +1248,7 @@ def make_names_and_values_dict_func_gpt_fast(q, qparams): def make_names_and_values_dict_func(q, qparams): k = q.shape[1] - new_k = _calc_padded_size_linear_int4(k, groupsize) + new_k = find_multiple(k, 1 if groupsize is None else groupsize) # how much we need to pad the weight delta_k = new_k - q.shape[1] final_q = F.pad(q, pad=(0, delta_k)) @@ -1196,38 +1290,3 @@ def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: A model = self._convert_for_runtime(model) model.load_state_dict(state_dict, strict=False) return model - - - # TODO: consolidate with other quantizers - class Int4WeightGPTQQuantizer(Int8DynActInt4WeightGPTQQuantizer): - - def __init__( - self, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - inner_k_tiles=8, - padding_allowed=True, - precision=torch.float32, - _use_cuda=True, - ): - super().__init__( - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - inner_k_tiles=inner_k_tiles, - padding_allowed=padding_allowed, - precision=precision, - _is_gpt_fast=_is_gpt_fast, - _use_cuda=_use_cuda, - ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a17daf8697..e7ce92976b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -30,6 +30,10 @@ ) from .weight_only import WeightOnlyInt8QuantLinear from .unified import Quantizer, TwoStepQuantizer +from .GPTQ import ( + Int4WeightOnlyGPTQQuantizer, +) + __all__ = [ "apply_weight_only_int8_quant", @@ -40,20 +44,19 @@ "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", + "Int4WeightOnlyGPTQQuantizer", ] if TORCH_VERSION_AFTER_2_3: from .GPTQ import ( Int8DynActInt4WeightQuantizer, Int8DynActInt4WeightGPTQQuantizer, - Int4WeightQuantizer, - Int4WeightGPTQQuantizer, + ) __all__ += [ "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer", - "Int4WeightQuantizer", - "Int4WeightGPTQQuantizer", + ] diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 5baa289729..8b6cc9cc7f 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -500,23 +500,6 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float ) -def pack_scales_and_zeros(scales, zeros, precision=torch.float16): - assert scales.shape == zeros.shape - assert scales.dtype == precision - assert zeros.dtype == precision - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - if TORCH_VERSION_AFTER_2_3: def group_quantize_tensor_symmetric( w, @@ -591,4 +574,4 @@ def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor: input = torch.ops.quantized_decomposed.dequantize_per_token( input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype ) - return input + return input.to(orig_dtype) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 1f6b3a9bcf..a178edf125 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -3,11 +3,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import torch from torch.utils._python_dispatch import TorchDispatchMode from packaging import version +from functools import reduce +from math import gcd + __all__ = [ "find_multiple", @@ -18,7 +21,8 @@ ] -def find_multiple(n: int, k: int) -> int: +def find_multiple(n: int, *args: Tuple[int]) -> int: + k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] if n % k == 0: return n return n + k - (n % k) From ec258e05c255ab996a47e1790ef32c1c492ca64c Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Thu, 4 Apr 2024 01:09:42 -0400 Subject: [PATCH 13/21] add int4 non-gptq and bugfixes (#119) Summary: int4weightlinear had a bug that made it not pad when it should have Test Plan: python test/quantization/test_quant_api.py -k "int4wo" Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 36 +++++++- torchao/quantization/GPTQ.py | 106 ++++++++++++++++++++--- torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 2 + torchao/quantization/quant_primitives.py | 1 - 5 files changed, 133 insertions(+), 14 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 700c9c9b98..5cc5ac1fa3 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -300,7 +300,6 @@ def test_gptq_quantizer_gpt_fast(self): @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_int4wo(self): from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper - # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -357,6 +356,41 @@ def test_gptq_quantizer_int4wo(self): f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" ) + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_quantizer_int4wo(self): + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + groupsize = 128 + quantizer = Int4WeightOnlyQuantizer( + groupsize, + ) + model = quantizer.quantize(model).cuda() + result = TransformerEvalWrapper( + model, + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( + f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" + ) + @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper(self): from torchao.quantization.GPTQ import TransformerEvalWrapper diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 559ab54f7d..d648507085 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -28,6 +28,7 @@ groupwise_affine_quantize_tensor_from_qparams, groupwise_affine_dequantize_tensor_from_qparams, pack_tinygemm_scales_and_zeros, + groupwise_affine_quantize_tensor, ) aten = torch.ops.aten @@ -65,8 +66,8 @@ __all__ = [ "MultiInput", - "WeightOnlyInt4Linear", "Int4WeightOnlyGPTQQuantizer", + "Int4WeightOnlyQuantizer", ] + add_ons if lm_eval_available: @@ -117,7 +118,10 @@ def __init__( @property def eot_token_id(self): - return self._tokenizer.eos_id() + try: + return self._tokenizer.eos_id() + except: + return self._tokenizer.eos_id @property def max_length(self): @@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs): # TODO: verify this for multi-batch as well tokens = self._tokenizer.encode(string) if hasattr(self._tokenizer, "bos_id"): - tokens = [self._tokenizer.bos_id()] + tokens + try: + tokens = [self._tokenizer.bos_id()] + tokens + except: + tokens = [self._tokenizer.bos_id] + tokens return tokens def tok_decode(self, tokens): @@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: pass +def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): + k_divisible_by_groupsize = k % groupsize == 0 + if inner_k_tiles is not None: + k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 + return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles + return k_divisible_by_groupsize def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() @@ -767,7 +780,7 @@ def __init__( bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, ) -> None: super().__init__() - self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) if self.padding: from model import find_multiple self.origin_in_features = in_features @@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) - -def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): - k_divisible_by_groupsize = k % groupsize == 0 - if inner_k_tiles is not None: - k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 - return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles - return k_divisible_by_groupsize - def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None): for name, child in module.named_children(): @@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c else: replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func) +class Int4WeightOnlyQuantizer(Quantizer): + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = True, + inner_k_tiles: Optional[int] = 8, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + + self.inner_k_tiles = inner_k_tiles + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + + @torch.no_grad() + def _create_quantized_state_dict( + self, model: torch.nn.Module + ) -> Dict[str, torch.Tensor]: + cur_state_dict = model.state_dict() + for fqn, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + # assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + assert ( + in_features % self.groupsize == 0 + ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0" + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding_allowed: + from .utils import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + ( + w_int4x8, + scales_and_zeros + ) = groupwise_affine_quantize_tensor( + weight, + 4, # n_bit + self.groupsize, + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda") + return cur_state_dict + + def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + ) + return model + + def quantize( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict(model) + model = self._convert_for_runtime(model) + # TODO: make it strict + model.load_state_dict(state_dict, strict=False) + return model + class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer): def __init__( self, diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 4bfb279769..12aa70039b 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -42,4 +42,6 @@ "compute_error", "get_model_size_in_bytes", "WeightOnlyInt8QuantLinear", + "Int4WeightOnlyGPTQQuantizer", + "Int4WeightOnlyQuantizer", ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e7ce92976b..581e312927 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -32,6 +32,7 @@ from .unified import Quantizer, TwoStepQuantizer from .GPTQ import ( Int4WeightOnlyGPTQQuantizer, + Int4WeightOnlyQuantizer, ) @@ -45,6 +46,7 @@ "Quantizer", "TwoStepQuantizer", "Int4WeightOnlyGPTQQuantizer", + "Int4WeightOnlyQuantizer" ] if TORCH_VERSION_AFTER_2_3: diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 8b6cc9cc7f..88eafd4b2a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -383,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros): def unpack_tinygemm_scales_and_zeros(scales_and_zeros): assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - assert scales_and_zeros.dtype == torch.float return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) From eba4c368fdc0f30c0c8316d7b11c75b3421fd01a Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Thu, 4 Apr 2024 01:23:07 -0400 Subject: [PATCH 14/21] fixing bug in GPTQ (#120) * fixing bug in GPTQ Summary: shape was always padded even when not needed. Test Plan: pythont test/quantization/test_quant_api.py -k "test_gptq_quantizer_int4wo" Reviewers: Subscribers: Tasks: Tags: * removing extra spaces Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/GPTQ.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index d648507085..bd82f36092 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -950,7 +950,10 @@ def __init__( # TODO: this is the gpt-fast version, merge with the main version later def make_names_and_values_dict_func(q, qparams): k = q.shape[1] - new_k = find_multiple(k, 1024) + if not _check_linear_int4_k(k, groupsize): + new_k = find_multiple(k, 1024) + else: + new_k = k # how much we need to pad the weight delta_k = new_k - q.shape[1] q = q.to(torch.int32) From 76e2ef59d11ca80a2de0b57d37bd93a66d914ad8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 4 Apr 2024 00:10:42 -0700 Subject: [PATCH 15/21] Support `model.to` int8 weight only quantized model (#122) Summary: registering fields as buffers so they get picked up in `model.to` Test Plan: python test/quantization/test_quant_api.py -k test_int8_wo_quant_save_load Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 26 ++++++++++++++++++++++++-- torchao/quantization/weight_only.py | 5 ++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 5cc5ac1fa3..d772fda831 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -8,6 +8,7 @@ # This test takes a long time to run import unittest import torch +import os from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantize_pt2e import ( prepare_pt2e, @@ -18,9 +19,10 @@ get_symmetric_quantization_config, ) -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.quantization.quant_api import apply_dynamic_quant from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + apply_dynamic_quant, + apply_weight_only_int8_quant, Quantizer, TwoStepQuantizer, ) @@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_int8_wo_quant_save_load(self): + m = M().eval().cpu() + apply_weight_only_int8_quant(m) + example_inputs = m.example_inputs() + ref = m(*example_inputs) + _TMP_FN = "_test.pt" + torch.save(m.state_dict(), _TMP_FN) + + state_dict = torch.load(_TMP_FN) + os.remove(_TMP_FN) + m2 = M().eval() + apply_weight_only_int8_quant(m2) + m2.load_state_dict(state_dict) + m2 = m2.to(device="cuda") + example_inputs = map(lambda x: x.cuda(), example_inputs) + res = m2(*example_inputs) + + torch.testing.assert_close(ref, res.cpu()) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_8da4w_quantizer(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py index 2ab5adf3d1..099df0f17f 100644 --- a/torchao/quantization/weight_only.py +++ b/torchao/quantization/weight_only.py @@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs): scales = kwargs.pop("scales") super().__init__(*args, **kwargs) - self.w_int8 = w_int8 - - self.scales = scales + self.register_buffer("w_int8", w_int8) + self.register_buffer("scales", scales) def forward(self, x, *args, **kwargs): """ From 8713b7d201578dd7d58f9cd87f3f74d277dc3c1a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 4 Apr 2024 12:01:19 -0700 Subject: [PATCH 16/21] torchao v0.1 version bump (#125) Summary: att Test Plan: / Reviewers: Subscribers: Tasks: Tags: --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e1cf9314a6..378339fa16 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ def read_requirements(file_path): package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao" # Version is year.month.date if using nightlies -version = current_date if package_name == "torchao-nightly" else "0.0.3" +version = current_date if package_name == "torchao-nightly" else "0.1" setup( From fc5d2c89915bbffc9a189d06cdc46537a12ef2a3 Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Fri, 5 Apr 2024 11:43:16 -0700 Subject: [PATCH 17/21] Support int_scaled_mm on CPU (#121) --- test/kernel/test_autotuner.py | 8 +++++--- torchao/kernel/intmm_triton.py | 6 ++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 33678995af..82fb117363 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -52,13 +52,15 @@ def test_int_mm(self, device, dtype): @parameterized.expand( [ ("cuda", torch.bfloat16), - # TODO: ("cpu", torch.bfloat16), + ("cpu", torch.bfloat16), ("cuda", torch.float16), - # TODO: ("cpu", torch.float16), + ("cpu", torch.float16), ] ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int_scaled_mm(self, device, dtype): + if device == "cuda" and not torch.cuda.is_available(): + self.skipTest(f"{device} not available") + from torchao.kernel import intmm dtype = torch.bfloat16 diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 4e84d9cd3c..d10dac0abe 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -356,3 +356,9 @@ def int_scaled_matmul_cuda(a, b, scales1): int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs ) return int_scaled_matmul_kernel(a, b, scales1, c, best_config) + + +@torch.library.impl(lib, "int_scaled_matmul", "CPU") +def int_scaled_matmul_cpu(a, b, scales1): + c = torch._int_mm(a, b) + return c.to(scales1.dtype) * scales1 From c40358072f99b50cd7e58ec11e0e8d90440e3e25 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Fri, 5 Apr 2024 15:25:52 -0700 Subject: [PATCH 18/21] Reapply Autoquant (#82) (#109) --- .github/workflows/regression_test.yml | 2 +- README.md | 33 ++- test/integration/test_integration.py | 113 ++++++++ torchao/__init__.py | 13 +- torchao/quantization/__init__.py | 4 + torchao/quantization/autoquant.py | 390 ++++++++++++++++++++++++++ torchao/quantization/quant_api.py | 7 +- torchao/quantization/subclass.py | 2 +- 8 files changed, 546 insertions(+), 18 deletions(-) create mode 100644 torchao/quantization/autoquant.py diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 560df57c3d..0a53eb911e 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -53,4 +53,4 @@ jobs: - name: Run tests run: | - pytest test --verbose -s -x + pytest test --verbose -s diff --git a/README.md b/README.md index 8adcd9c24c..00609bf298 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# torchao: PyTorch Architecture Optimization +# torchao: PyTorch Architecture Optimization **Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue** -The `torchao` package allows you to quantize and prune your models using native PyTorch. +The `torchao` package allows you to quantize and prune your models using native PyTorch. The repo hosts both 1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4 @@ -38,30 +38,46 @@ pip install -e . Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. -### A8W8 Dynamic Quantization +### Autoquantization -```Python +The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes +of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer. + +```python import torch -from torchao.quantization import quant_api +import torchao -# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor +# inductor settings which improve torch.compile performance for quantized modules torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.use_mixed_mm = True # Plug in your model and example input model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') -# convert linear modules to quantized linear modules -quant_api.change_linear_weights_to_int8_dqtensors(model) +# perform autoquantization +torchao.autoquant(model, (input)) # compile the model to improve performance model = torch.compile(model, mode='max-autotune') model(input) ``` + +### A8W8 Dynamic Quantization + +```python +# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor +torch._inductor.config.force_fuse_int_mm_with_mul = True +from torchao.quantization import quant_api +# convert linear modules to quantized tensor subclasses +quant_api.change_linear_weights_to_int8_dqtensors(model) +``` + ### A16W8 WeightOnly Quantization ```python +from torchao.quantization import quant_api quant_api.change_linear_weights_to_int8_woqtensors(model) ``` @@ -71,6 +87,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is ### A16W4 WeightOnly Quantization ```python +from torchao.quantization import quant_api quant_api.change_linear_weights_to_int4_woqtensors(model) ``` diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 801ca10bc2..2425d341e2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -7,11 +7,13 @@ # mypy: ignore-errors import copy import unittest +import itertools import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code from torch._dynamo import config +import torchao from torch.ao.quantization import MinMaxObserver, QConfigMapping from torchao.quantization.dynamic_quant import ( @@ -54,6 +56,13 @@ _fqn_to_op_to_shape_to_count, LoggingTensorMode, ) +from torchao.quantization.autoquant import ( + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + AQWeightOnlyQuantizedLinearWeight3 + +) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os from parameterized import parameterized @@ -71,6 +80,12 @@ ("cuda", torch.bfloat16), ] +def combine_parameters(a, b): + new_tuples = [] + for (tuple1, tuple2) in itertools.product(a, b): + new_tuples.append(tuple1 + tuple2) + return new_tuples + def run_supported_device_dtype(test_method): def wrapper(*args, **kwargs): if args[2] == "cuda" and not torch.cuda.is_available(): @@ -907,6 +922,36 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype ) + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_dynamic_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype + ) + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass(self, device, dtype): @@ -1290,6 +1335,74 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) +class TestAutoQuant(unittest.TestCase): + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, + [ + (16, 128, 128), + (64, 128, 128), + # (2**15, 128, 128), TODO: Runs out of shared memory on T4 + (16, 128, 256), + # (64, 128, 256), # TODO: Runs out of shared memory on T4 + (16, 256, 128), + (64, 256, 128), + # (256, 256, 128), TODO: Runs out of shared memory on T4 + ])) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") + def test_autoquant_one_input(self, device, dtype, m, k, n): + print("(m, k, n): ", (m, k, n)) + if device != "cuda" or not torch.cuda.is_available(): + self.skipTest(f"autoquant currently does not support {device}") + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): + if dtype == torch.bfloat16: + self.skipTest(f"bfloat16 requires sm80+") + if m == 1: + self.skipTest(f"Shape {(m, k, n)} requires sm80+") + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.use_mixed_mm = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._dynamo.config.automatic_dynamic_shapes = False + + example_input = torch.randn(m, k, device=device, dtype=dtype) + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + out = model(example_input) + torchao.autoquant(model, example_input) + out2 = model(example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) + + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, + [ + (1, 1, 128, 128), + (1, 32, 128, 128), + (32, 32, 128, 128), + ])) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") + def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n): + if device != "cuda" or not torch.cuda.is_available(): + self.skipTest(f"autoquant currently does not support {device}") + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): + if dtype == torch.bfloat16: + self.skipTest(f"bfloat16 requires sm80+") + if m1 == 1 or m2 == 1: + self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + example_input = torch.randn(m1, k, device=device, dtype=dtype) + example_input2 = torch.randn(m2, k, device=device, dtype=dtype) + torchao.quantization.change_linears_to_autoquantizable(model) + out=model(example_input) + model(example_input2) + torchao.quantization.change_autoquantizable_to_quantized(model) + out2 = model(example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) if __name__ == "__main__": unittest.main() diff --git a/torchao/__init__.py b/torchao/__init__.py index b0dc9b1e1e..ecd2ccf4b9 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,8 +1,13 @@ +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + autoquant, +) from . import dtypes -from .quantization.quant_api import apply_dynamic_quant -from .quantization.quant_api import apply_weight_only_int8_quant __all__ = [ - "dtypes", - "apply_dynamic_quant", + "dtypes", + "apply_dynamic_quant", + "apply_weight_only_int8_quant", + "autoquant", ] diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 12aa70039b..ab51dbb3a5 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -11,6 +11,7 @@ from .utils import * # noqa: F403 from .weight_only import * # noqa: F403 from .unified import * +from .autoquant import * __all__ = [ "DynamicallyPerAxisQuantizedLinear", @@ -26,6 +27,9 @@ "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", + "autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", "quant_int8_dynamic_linear", "quant_int8_matmul", "quant_int8_dynamic_per_token_linear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py new file mode 100644 index 0000000000..f1f387d7b5 --- /dev/null +++ b/torchao/quantization/autoquant.py @@ -0,0 +1,390 @@ +import torch +from .subclass import ( # noqa + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + QuantizedLinearWeightBase, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from .quant_primitives import ( + quantize_activation_per_token_absmax, + safe_int_mm, +) +import torch.nn.functional as F +from torch._inductor.utils import do_bench +aten = torch.ops.aten + +AUTOQUANT_CACHE = {} + +def check_cache(cls, shapes_and_dtype): + return AUTOQUANT_CACHE.get((cls,)+shapes_and_dtype, None) + +def update_cache(cls, shapes_and_dtype, res): + AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res + +class AutoQuantizableLinearWeight(torch.Tensor): + """ + when run, finds best type of quantization for this tensor and swaps itself with that + """ + @staticmethod + def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = ( + kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype + ) + kwargs["requires_grad"] = False + shape = kwargs.pop("shape", weight.shape) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + self.weight = weight + self.qtensor_class_list = qtensor_class_list + self.logged_data = {} + self.mode = mode + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})" + ) + + @staticmethod + def log_shape(act_mat, w_autoquant, bias): + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + logged_dtype = act_mat.dtype + logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,) + shapes_and_dtype = logged_shapes + (logged_dtype,) + w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0) + for q_cls in w_autoquant.qtensor_class_list: + if check_cache(q_cls, shapes_and_dtype) is None: + update_cache(q_cls, shapes_and_dtype, None) + + def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): + act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype + if check_cache(q_cls, shapes_and_dtype) is None: + with torch.no_grad(): + act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) + bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device) + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + update_cache(q_cls, shapes_and_dtype, res) + + def to_quantized(self, error_on_unseen, **kwargs): + if error_on_unseen and self.logged_data == {}: + raise RuntimeError("must run module normally to get shape, dtype info for autoquant") + elif (self.logged_data == {}) and not error_on_unseen: + # default back to non-quantized weight if not seen + self = AQFloatLinearWeight.from_float(self.weight) + return self + + + # only want to do shape+final print a single time if multiple layers + # see/have same shapes so we gate on check_cache being empty for + # at least one of the class/shape combinations. + do_final_print = False + print_once = True + + def count_shapes(self, do_print=True): + differe_shape_count=0 + for shapes_and_dtype, times_seen in self.logged_data.items(): + differe_shape_count += 1 + if do_print: + act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype + print(f"activation_shapes: {act_shape}, times_seen: {times_seen}") + if do_print: + print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}") + return differe_shape_count + + # check each class + best_time = torch.inf + best_cls = None + for q_cls in self.qtensor_class_list: + # for each logged shape+dtype, benchmark + cur_time=0 + shape_count = count_shapes(self, do_print=False) + for shapes_and_dtype, times_seen in self.logged_data.items(): + if check_cache(q_cls, shapes_and_dtype) is None: + # only do final print if we have to autotune at least one cls/shape pair + do_final_print=True + + # only print shapes once + if print_once == True: + print_once = False + count_shapes(self, do_print=True) + + time_for_best_shape = check_cache(best_cls, shapes_and_dtype) + time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape + self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape) + torch._dynamo.reset() + cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen + if shape_count is not None and shape_count > 1: + print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") + if best_time >= cur_time: + best_time = cur_time + best_cls = q_cls + # only print if this is the first time seeing some cls+shape combo, + # otherwise we will print the same thing for every layer. + if do_final_print: + print(f"best_cls={best_cls}\n") + # TODO handle random cls args/kwargs? or should they be curried? + self = best_cls.from_float(self.weight) + return self + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode + ) + + def __tensor_flatten__(self): + return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + weight = tensor_data_dict["weight"] + qtensor_class_list, mode, dtype, shape = tensor_attributes[0] + return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + + @classmethod + def from_float(cls, weight, qtensor_class_list, **kwargs): + return cls(weight, qtensor_class_list, **kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + mat1, w_autoquant, bias = ( + args[0], + args[1], + args[2] if len(args)>2 else None + ) + cls.log_shape(mat1, w_autoquant, bias) + return func(mat1, w_autoquant.weight, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func is aten.detach.default: + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + +def do_autoquant_bench(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ + rep = kwargs.pop("rep", 100) + warmup = kwargs.pop("warmup", 25) + with torch.no_grad(): + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + op(*args, **kwargs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + op(*args, **kwargs) + res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + return res + +def _is_interpolate_mode(mode): + if isinstance(mode, list) and mode[0]=="interpolate" and len(mode)==2 and isinstance(mode[1], float): + return True + return False + +class AQMixin(): + """ + Mixin to turn normal quantized subclasses into autoquantizable ones + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + w_qtensor = cls.from_float(weight) + if _is_interpolate_mode(mode): + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs") + else: + func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c)) + q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100) + if res < best_time*1.1: + res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900) + res=(res2*.9+res*.1) + print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") + return res + +class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): + """ + AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + if not _is_interpolate_mode(mode): + return super()._autoquant_test(act_mat, weight, bias, best_time, mode) + + # SAM best is between .8 and 1, SDXL also performs best in this range + INTERPOLATION_CONSTANT = mode[1] + w_qtensor = cls.from_float(weight) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_int8, x_scales, w_vals_int8: + safe_int_mm(x_vals_int8, w_vals_int8) * x_scales + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") + with torch.no_grad(): + res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) + print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") + + # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op + if res_matmul>=best_time: + return res_matmul + + # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT + to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) + res = super()._autoquant_test(act_mat, weight, bias, to_beat) + max_int_const_win = (best_time-res_matmul)/(res-res_matmul) + res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul + print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") + return res_f + +class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight + """ + +class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + orig_shape = act_mat.shape + act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) + y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales + if bias is not None: + y += bias + return y.to(orig_dtype) + + @classmethod + def _autoquant_test(cls, act_mat, *args): + # if act_mat has batchsize>2 don't use this kernel + if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32: + return torch.inf + return super()._autoquant_test(act_mat, *args) + +class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + def _quantized_op(act_mat, w_qtensor, bias): + orig_shape = act_mat.shape + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales) + y=y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y + +class AQFloatLinearWeight(torch.Tensor, AQMixin): + """ + A class to be used in concert with AutoQuantizableLinearWeight to provide a + default/non-quantized option. Only implements the bare minimum needed to work with the + AutoQuantizableLinearWeight class using the same interfaces that would normally be + used by QTensor subclasses but for a default linear op instead. Result of from_float + is not a tensor subclass, but rather the float tensor. + """ + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor, bias) + + @classmethod + def from_float(cls, weight): + return weight + +DEFAULT_CLASS_LIST = [ + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + # AQWeightOnlyQuantizedLinearWeight3, + # TODO this gets picked in places where it makes perf worse, why? +] + +def change_linears_to_autoquantizable(model, **kwargs): + """ + Converts all linear weight tensors to the + AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed + by running the model and then calling change_autoquantizable_to_quantized + """ + from torchao.quantization.quant_api import _is_linear + filter_fn = kwargs.pop("filter_fn", _is_linear) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["mode"] = kwargs.get("mode", ["relu", None]) + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + from torchao.quantization.quant_api import _get_subclass_inserter + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs), + filter_fn if filter_fn is not None else _is_linear, + ) + +def change_autoquantizable_to_quantized(model, **kwargs): + """ + Converts AutoQuantizableLinearWeight tensor subclasses + to various quantized/non-quantized tensor subclasses depending + on benchmark results. Expectation is that these modules are + torch.compiled afterwards. + """ + hold = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False + + filter_fn = kwargs.pop( + "filter_fn", + lambda mod, *args: + hasattr(mod, "weight") and isinstance(mod.weight, AutoQuantizableLinearWeight) + ) + error_on_unseen=kwargs.pop("error_on_unseen", True) + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + from torchao.quantization.quant_api import _get_subclass_inserter + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter( + AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs + ), + filter_fn, + ) + torch._dynamo.config.automatic_dynamic_shapes = hold + torch._dynamo.reset() + +@torch.no_grad() +def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs): + """ + Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape + across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer + and applies that type of quantization. + """ + if filter_fn is None: + from torchao.quantization.quant_api import _is_linear + filter_fn = _is_linear + + change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) + if not isinstance(example_input, (tuple, list)): + assert isinstance(example_input, torch.Tensor) + example_input = [example_input] + model(*example_input) + change_autoquantizable_to_quantized(model, **kwargs) + return model + diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 581e312927..a830d52d78 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -58,7 +58,6 @@ __all__ += [ "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer", - ] @@ -119,10 +118,11 @@ def apply_dynamic_quant(model, filter_fn=None): def _get_subclass_inserter(cls, **kwargs): - + method = kwargs.pop("method", "from_float") def insert_subclass(lin): lin.weight = torch.nn.Parameter( - cls.from_float(lin.weight, **kwargs), requires_grad=False + # cls.from_float(...) + getattr(cls, method)(lin.weight, **kwargs), requires_grad=False ) return lin @@ -174,7 +174,6 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): filter_fn, ) - def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 537099f67a..64689b8d95 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -301,7 +301,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127): # however the external representation of our tensor will maintain the correct # shape attribute which needs to be tracked directly. int_data = w_int_repr.contiguous().t() - if cls is not Int8DynamicallyQuantizedLinearWeight: + if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight): int_data = int_data.contiguous() return cls( int_data, w_scales, False, input_float.shape, dtype=input_float.dtype From b9beaf351e27133d189b57d6fa725b1a7824a457 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 12 Apr 2024 10:32:01 -0700 Subject: [PATCH 19/21] Allow cpu and gpu in int4wo and int4wo-gptq quantizer (#131) Summary: att Test Plan: verified in torchat Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/GPTQ.py | 57 ++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index bd82f36092..91e604c8cf 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -764,7 +764,12 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + c = torch.ops.aten._weight_int4pack_mm( + x.to(torch.bfloat16), + weight_int4pack, + groupsize, + scales_and_zeros.to(torch.bfloat16) + ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c @@ -776,8 +781,8 @@ class WeightOnlyInt4Linear(torch.nn.Module): weight: torch.Tensor def __init__( - self, in_features: int, out_features: int, - bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True, + self, in_features: int, out_features: int, + bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, ) -> None: super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) @@ -794,23 +799,16 @@ def __init__( assert out_features % 8 == 0, "require out_features % 8 == 0" assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" - if use_cuda: - self.register_buffer( - "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) - ) - else: - self.register_buffer( - "weight", - torch.empty((out_features, in_features // 2), dtype=torch.uint8) - ) + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) self.register_buffer( "scales_and_zeros", torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) ) def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(torch.bfloat16) if self.padding: import torch.nn.functional as F input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) @@ -819,17 +817,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) -def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None): +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None): for name, child in module.named_children(): if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: setattr(module, name, WeightOnlyInt4Linear( child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda + groupsize=groupsize, inner_k_tiles=inner_k_tiles, )) else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func) + replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func) class Int4WeightOnlyQuantizer(Quantizer): def __init__( @@ -837,6 +835,7 @@ def __init__( groupsize: int = 256, padding_allowed: bool = True, inner_k_tiles: Optional[int] = 8, + device: torch.device = torch.device("cuda"), ) -> None: super().__init__() assert inner_k_tiles in [2, 4, 8] @@ -845,6 +844,7 @@ def __init__( self.inner_k_tiles = inner_k_tiles self.groupsize: int = groupsize self.padding_allowed: bool = padding_allowed + self.device: torch.device = device @torch.no_grad() def _create_quantized_state_dict( @@ -885,9 +885,9 @@ def _create_quantized_state_dict( 4, # n_bit self.groupsize, ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles) - cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda") - cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda") + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device) return cur_state_dict def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: @@ -916,12 +916,14 @@ def __init__( groupsize, inner_k_tiles=8, padding_allowed=True, + device: torch.device = torch.device("cuda"), ): self.blocksize = blocksize self.percdamp = percdamp self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles self.padding_allowed = padding_allowed + self.device = device self.act_fake_quant_func = None n_bit = 4 self.get_qparams_func = lambda w: get_groupwise_affine_qparams( @@ -956,10 +958,10 @@ def make_names_and_values_dict_func(q, qparams): new_k = k # how much we need to pad the weight delta_k = new_k - q.shape[1] - q = q.to(torch.int32) + q = q.to(torch.int32).to(self.device) final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) - scales = qparams[0].to(torch.bfloat16) - zeros = qparams[1].to(torch.bfloat16) + scales = qparams[0].to(torch.bfloat16).to(self.device) + zeros = qparams[1].to(torch.bfloat16).to(self.device) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) # how many new groups we need for padded weight delta_groups = new_k // groupsize - scales_and_zeros.shape[0] @@ -970,13 +972,12 @@ def make_names_and_values_dict_func(q, qparams): super().__init__() def _convert_for_runtime(self, model): - # TODO: temporary path for gpt-fast, will remove later replace_linear_int4( model, self.groupsize, self.inner_k_tiles, self.padding_allowed, - skip_layer_func = self.skip_layer_func, + skip_layer_func=self.skip_layer_func, ) return model @@ -1160,7 +1161,6 @@ def __init__( scales_precision: torch.dtype = torch.float32, inner_k_tiles: Optional[int] = None, _is_gpt_fast: bool = False, - _use_cuda: bool = True, ) -> None: super().__init__() if _is_gpt_fast: @@ -1169,7 +1169,6 @@ def __init__( else: assert inner_k_tiles is None self._is_gpt_fast = _is_gpt_fast - self._use_cuda = _use_cuda self.inner_k_tiles = inner_k_tiles self.groupsize: int = groupsize self.padding_allowed: bool = padding_allowed @@ -1238,7 +1237,6 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: self.groupsize, self.inner_k_tiles, self.padding_allowed, - self._use_cuda, ) else: replace_linear_8da4w( @@ -1270,10 +1268,8 @@ def __init__( padding_allowed=True, precision=torch.float32, _is_gpt_fast=False, - _use_cuda=True, ): self._is_gpt_fast = _is_gpt_fast - self._use_cuda = _use_cuda self.blocksize = blocksize self.percdamp = percdamp self.groupsize = groupsize @@ -1352,7 +1348,6 @@ def _convert_for_runtime(self, model): self.groupsize, self.inner_k_tiles, self.padding_allowed, - self._use_cuda, ) else: replace_linear_8da4w( From 5401df093564825c06691f4c2c10cdcf1a32a40c Mon Sep 17 00:00:00 2001 From: supriyar Date: Mon, 15 Apr 2024 21:19:12 -0700 Subject: [PATCH 20/21] Update README.md (#140) * Update README.md Update to include 1. the interoperability details 2. our goals 3. some general cleanup --- README.md | 52 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 00609bf298..5d7e1d8f47 100644 --- a/README.md +++ b/README.md @@ -2,17 +2,36 @@ **Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue** -The `torchao` package allows you to quantize and prune your models using native PyTorch. +## Introduction -The repo hosts both -1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4 -2. Quantization [algorithms](./torchao/quantization) such as dynamic quant, smoothquant -3. Sparsity [algorithms](./torchao/sparsity) such as Wanda +torchao is a PyTorch native library for optimizing your models using lower precision dtypes, techniques like quantization and sparsity and performant kernels. + +The library provides +1. Support for lower precision [dtypes](./torchao/dtypes) such as nf4, uint4 that are torch.compile friendly +2. Quantization [algorithms](./torchao/quantization) such as dynamic quant, smoothquant, GPTQ that run on CPU/GPU and Mobile. +3. Sparsity [algorithms](./torchao/sparsity) such as Wanda that help improve accuracy of sparse networks +4. Integration with other PyTorch native libraries like torchtune and ExecuTorch + +## Key Features +* Native PyTorch techniques, composable with torch.compile +* High level `autoquant` API and kernel auto tuner targeting SOTA performance across varying model shapes on consumer/enterprise GPUs. +* Quantization techniques and kernels that work with both eager and torch.compile + * Int8 dynamic activation quantization + * Int8 and int4 weight-only quantization + * Int8 dynamic activation quantization with int4 weight quantization + * [GPTQ](https://arxiv.org/abs/2210.17323) and [Smoothquant](https://arxiv.org/abs/2211.10438) + +## Interoperability with PyTorch Libraries + +torchao has been integrated with other repositories to ease usage + +* [torchtune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md) is integrated with 8 and 4 bit weight-only quantization techniques with and without GPTQ. +* [Executorch](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#quantization) is integrated with GPTQ for both 8da4w (int8 dynamic activation, with int4 weight) and int4 weight only quantization. ## Success stories Our kernels have has been used to achieve SOTA inference performance on -1. Image segmentation modelss with [sam-fast](pytorch.org/blog/accelerating-generative-ai) +1. Image segmentation models with [sam-fast](pytorch.org/blog/accelerating-generative-ai) 2. Language models with [gpt-fast](pytorch.org/blog/accelerating-generative-ai-2) 3. Diffusion models with [sd-fast](pytorch.org/blog/accelerating-generative-ai-3) @@ -34,10 +53,23 @@ cd ao pip install -e . ``` +## Our Goals +torchao embodies PyTorch’s design philosophy [details](https://pytorch.org/docs/stable/community/design.html), especially "usability over everything else". Our vision for this repository is the following: + +* Composability: Native solutions for optimization techniques that compose with both `torch.compile` and `FSDP` + * For example, for QLoRA for new dtypes support +* Interoperability: Work with the rest of the PyTorch ecosystem such as torchtune, gpt-fast and ExecuTorch +* Transparent Benchmarks: Regularly run performance benchmarking of our APIs across a suite of Torchbench models and across hardware backends +* Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU based server (w/ torch.compile) and mobile backends (w/ ExecuTorch). +* Infrastructure Support: Release packaging solution for kernels and a CI/CD setup that runs these kernels on different backends. + + + ## Examples Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. + ### Autoquantization The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes @@ -133,10 +165,12 @@ model = torch.compile(model, mode='max-autotune') model(input) ``` -## Sharp edges -1. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance. -2. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible. +## Notes + +1. APIs have been hardware tested on A100 and T4(colab) +2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance. +3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible. ## License From d76ecc2df80c37c72dcdd2ed8e55202756a075c8 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 16 Apr 2024 10:34:04 -0700 Subject: [PATCH 21/21] Fix intmm benchmark script (#141) --- benchmarks/intmm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/benchmarks/intmm.py b/benchmarks/intmm.py index 7f1df1bee1..edadf2a7cf 100644 --- a/benchmarks/intmm.py +++ b/benchmarks/intmm.py @@ -2,12 +2,21 @@ import csv import itertools import math +import sys import pathlib import torch +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_2 + + +# Check if CUDA is available, if not, exit the script +if not torch.cuda.is_available(): + print("CUDA is not available. Exiting the script.") + sys.exit(0) + import torch.nn.functional as F import torch.utils.benchmark as benchmark -from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul +from torchao.kernel.intmm import int_matmul, int_scaled_matmul torch._dynamo.config.cache_size_limit = 128 torch._dynamo.config.accumulated_cache_size_limit = 128 @@ -81,7 +90,7 @@ def run_benchmarks(shapes): if __name__ == "__main__": parser = argparse.ArgumentParser(description="integer matmul benchmarks") - parser.add_argument("file_path", type=str, help="Path to csv file with shapes") + parser.add_argument("--file_path", type=str, required=True, help="Path to csv file with shapes") args = parser.parse_args() # Access the file path provided as an argument file_path = args.file_path