diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index a1bee9a23b..0a53eb911e 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -10,22 +10,42 @@ on: jobs: test: - runs-on: 4-core-ubuntu-gpu-t4 + strategy: + matrix: + include: + - name: CUDA 2.2.2 + runs-on: 4-core-ubuntu-gpu-t4 + torch-spec: 'torch==2.2.2' + - name: CUDA 2.3 RC + runs-on: 4-core-ubuntu-gpu-t4 + torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/test/cu121' + - name: CUDA Nightly + runs-on: 4-core-ubuntu-gpu-t4 + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + - name: CPU 2.2.2 + runs-on: 32-core-ubuntu + torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu' + - name: CPU 2.3 RC + runs-on: 32-core-ubuntu + torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/test/cpu' + - name: Nightly CPU + runs-on: 32-core-ubuntu + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' + runs-on: ${{ matrix.runs-on }} steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: '3.9' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install torch + pip install ${{ matrix.torch-spec }} pip install -r requirements.txt pip install -r dev-requirements.txt - - name: Install package run: | @@ -33,82 +53,4 @@ jobs: - name: Run tests run: | - pytest test --verbose -s -x - - test-nightly: - runs-on: 4-core-ubuntu-gpu-t4 - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 - pip install -r requirements.txt - pip install -r dev-requirements.txt - - - - name: Install package - run: | - pip install . - - - name: Run tests - run: | - pytest test --verbose -s -x - - test-cpu: - runs-on: 32-core-ubuntu - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -r requirements.txt - pip install -r dev-requirements.txt - - - - name: Install package - run: | - pip install . - - - name: Run tests - run: | - pytest test --verbose -s -x - - test-nightly-cpu: - runs-on: 32-core-ubuntu - steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - pip install -r requirements.txt - pip install -r dev-requirements.txt - - - - name: Install package - run: | - pip install . 
- - name: Run tests - run: | - pytest test --verbose -s -x + pytest test --verbose -s diff --git a/README.md b/README.md index 8adcd9a23b..5d7e1d8f47 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,37 @@ -# torchao: PyTorch Architecture Optimization +# torchao: PyTorch Architecture Optimization **Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue** -The `torchao` package allows you to quantize and prune your models using native PyTorch. +## Introduction -The repo hosts both -1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4 -2. Quantization [algorithms](./torchao/quantization) such as dynamic quant, smoothquant -3. Sparsity [algorithms](./torchao/sparsity) such as Wanda +torchao is a PyTorch native library for optimizing your models using lower precision dtypes, techniques like quantization and sparsity, and performant kernels. + +The library provides +1. Support for lower precision [dtypes](./torchao/dtypes) such as nf4, uint4 that are torch.compile friendly +2. Quantization [algorithms](./torchao/quantization) such as dynamic quant, smoothquant, GPTQ that run on CPU/GPU and mobile. +3. Sparsity [algorithms](./torchao/sparsity) such as Wanda that help improve the accuracy of sparse networks +4. Integration with other PyTorch native libraries like torchtune and ExecuTorch + +## Key Features +* Native PyTorch techniques, composable with torch.compile +* High level `autoquant` API and kernel auto tuner targeting SOTA performance across varying model shapes on consumer/enterprise GPUs. +* Quantization techniques and kernels that work with both eager and torch.compile + * Int8 dynamic activation quantization + * Int8 and int4 weight-only quantization + * Int8 dynamic activation quantization with int4 weight quantization + * [GPTQ](https://arxiv.org/abs/2210.17323) and [Smoothquant](https://arxiv.org/abs/2211.10438) + +## Interoperability with PyTorch Libraries + +torchao has been integrated with other repositories to ease usage: + +* [torchtune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md) is integrated with 8 and 4 bit weight-only quantization techniques with and without GPTQ. +* [ExecuTorch](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#quantization) is integrated with GPTQ for both 8da4w (int8 dynamic activation, with int4 weight) and int4 weight-only quantization. ## Success stories Our kernels have has been used to achieve SOTA inference performance on -1. Image segmentation modelss with [sam-fast](pytorch.org/blog/accelerating-generative-ai) +1. Image segmentation models with [sam-fast](pytorch.org/blog/accelerating-generative-ai) 2. Language models with [gpt-fast](pytorch.org/blog/accelerating-generative-ai-2) 3. Diffusion models with [sd-fast](pytorch.org/blog/accelerating-generative-ai-3) @@ -34,34 +53,63 @@ cd ao pip install -e . ``` +## Our Goals +torchao embodies PyTorch’s design philosophy ([details](https://pytorch.org/docs/stable/community/design.html)), especially "usability over everything else".
Our vision for this repository is the following: + +* Composability: Native solutions for optimization techniques that compose with both `torch.compile` and `FSDP` + * For example, for QLoRA and for new dtype support +* Interoperability: Work with the rest of the PyTorch ecosystem such as torchtune, gpt-fast and ExecuTorch +* Transparent Benchmarks: Regularly run performance benchmarking of our APIs across a suite of Torchbench models and across hardware backends +* Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU based server (w/ torch.compile) and mobile backends (w/ ExecuTorch). +* Infrastructure Support: Release packaging solution for kernels and a CI/CD setup that runs these kernels on different backends. + + + ## Examples Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. -### A8W8 Dynamic Quantization -```Python +### Autoquantization + +The `autoquant` API can be used to quickly and accurately quantize your model. When used as in the example below, the API first identifies the shapes +of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps in that class for the corresponding linear layer. Currently this API chooses between no quantization, int8 dynamic quantization, and int8 weight-only quantization for each layer. + +```python import torch -from torchao.quantization import quant_api +import torchao -# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor +# inductor settings which improve torch.compile performance for quantized modules torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.use_mixed_mm = True # Plug in your model and example input model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') -# convert linear modules to quantized linear modules -quant_api.change_linear_weights_to_int8_dqtensors(model) +# perform autoquantization +torchao.autoquant(model, (input)) # compile the model to improve performance model = torch.compile(model, mode='max-autotune') model(input) ``` + +### A8W8 Dynamic Quantization + +```python +# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor +torch._inductor.config.force_fuse_int_mm_with_mul = True +from torchao.quantization import quant_api +# convert linear modules to quantized tensor subclasses +quant_api.change_linear_weights_to_int8_dqtensors(model) +``` + ### A16W8 WeightOnly Quantization ```python +from torchao.quantization import quant_api quant_api.change_linear_weights_to_int8_woqtensors(model) ``` @@ -71,6 +119,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is ### A16W4 WeightOnly Quantization ```python +from torchao.quantization import quant_api quant_api.change_linear_weights_to_int4_woqtensors(model) ``` @@ -116,10 +165,12 @@ model = torch.compile(model, mode='max-autotune') model(input) ``` -## Sharp edges -1.
While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance. -2. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible. +## Notes + +1. APIs have been hardware tested on A100 and T4 (Colab) +2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance. +3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), which are preferred over older module swap based methods because they don't modify the graph and are generally more composable and flexible. ## License diff --git a/benchmarks/intmm.py b/benchmarks/intmm.py index 7f1df1bee1..edadf2a7cf 100644 --- a/benchmarks/intmm.py +++ b/benchmarks/intmm.py @@ -2,12 +2,21 @@ import csv import itertools import math +import sys import pathlib import torch +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_2 + + +# Check if CUDA is available; if not, exit the script +if not torch.cuda.is_available(): + print("CUDA is not available.
Exiting the script.") + sys.exit(0) + import torch.nn.functional as F import torch.utils.benchmark as benchmark -from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul +from torchao.kernel.intmm import int_matmul, int_scaled_matmul torch._dynamo.config.cache_size_limit = 128 torch._dynamo.config.accumulated_cache_size_limit = 128 @@ -81,7 +90,7 @@ def run_benchmarks(shapes): if __name__ == "__main__": parser = argparse.ArgumentParser(description="integer matmul benchmarks") - parser.add_argument("file_path", type=str, help="Path to csv file with shapes") + parser.add_argument("--file_path", type=str, required=True, help="Path to csv file with shapes") args = parser.parse_args() # Access the file path provided as an argument file_path = args.file_path diff --git a/setup.py b/setup.py index 9f5f8e7745..27c1f260e8 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ def read_requirements(file_path): package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao" # Version is year.month.date if using nightlies -version = current_date if package_name == "torchao-nightly" else "0.0.3" +version = current_date if package_name == "torchao-nightly" else "0.1" setup( diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 5cd967f9d2..c71fcd25b6 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -192,7 +192,6 @@ def test_smoketest_linear(self): out1 = torch.nn.functional.linear(inp, a) out2 = torch.nn.functional.linear(inp, a_nf4) - @unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_smoketest_linear_compile(self): for dtype in [torch.bfloat16, torch.float16]: diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index fd3a3311df..2425d341e2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -7,11 +7,13 @@ # mypy: ignore-errors import copy import unittest +import itertools import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code from torch._dynamo import config +import torchao from torch.ao.quantization import MinMaxObserver, QConfigMapping from torchao.quantization.dynamic_quant import ( @@ -54,10 +56,17 @@ _fqn_to_op_to_shape_to_count, LoggingTensorMode, ) +from torchao.quantization.autoquant import ( + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + AQWeightOnlyQuantizedLinearWeight3 + +) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os from parameterized import parameterized -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 torch.manual_seed(0) config.cache_size_limit = 100 @@ -71,6 +80,12 @@ ("cuda", torch.bfloat16), ] +def combine_parameters(a, b): + new_tuples = [] + for (tuple1, tuple2) in itertools.product(a, b): + new_tuples.append(tuple1 + tuple2) + return new_tuples + def run_supported_device_dtype(test_method): def wrapper(*args, **kwargs): if args[2] == "cuda" and not torch.cuda.is_available(): @@ -836,7 +851,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def 
test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") @@ -846,7 +861,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") @@ -908,7 +923,37 @@ def test_int8_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + def test_aq_int8_dynamic_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight2.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -918,7 +963,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -981,7 +1026,7 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_api(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -995,7 +1040,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1159,7 +1204,7 @@ def test_save_load_int8woqtensors(self, 
device, dtype): self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "int4 requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") @torch.no_grad() def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: @@ -1169,7 +1214,7 @@ def test_save_load_int4woqtensors(self, device, dtype): class TorchCompileUnitTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "fullgraph requires torch nightly.") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "fullgraph requires torch nightly.") def test_fullgraph(self): lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( @@ -1290,6 +1335,74 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) +class TestAutoQuant(unittest.TestCase): + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, + [ + (16, 128, 128), + (64, 128, 128), + # (2**15, 128, 128), TODO: Runs out of shared memory on T4 + (16, 128, 256), + # (64, 128, 256), # TODO: Runs out of shared memory on T4 + (16, 256, 128), + (64, 256, 128), + # (256, 256, 128), TODO: Runs out of shared memory on T4 + ])) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") + def test_autoquant_one_input(self, device, dtype, m, k, n): + print("(m, k, n): ", (m, k, n)) + if device != "cuda" or not torch.cuda.is_available(): + self.skipTest(f"autoquant currently does not support {device}") + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): + if dtype == torch.bfloat16: + self.skipTest(f"bfloat16 requires sm80+") + if m == 1: + self.skipTest(f"Shape {(m, k, n)} requires sm80+") + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.use_mixed_mm = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._dynamo.config.automatic_dynamic_shapes = False + + example_input = torch.randn(m, k, device=device, dtype=dtype) + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + out = model(example_input) + torchao.autoquant(model, example_input) + out2 = model(example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) + + @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, + [ + (1, 1, 128, 128), + (1, 32, 128, 128), + (32, 32, 128, 128), + ])) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") + def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n): + if device != "cuda" or not torch.cuda.is_available(): + self.skipTest(f"autoquant currently does not support {device}") + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0): + if dtype == torch.bfloat16: + self.skipTest(f"bfloat16 requires sm80+") + if m1 == 1 or m2 == 1: + self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+") + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + example_input = torch.randn(m1, k, device=device, dtype=dtype) + example_input2 = torch.randn(m2, k, device=device, dtype=dtype) + torchao.quantization.change_linears_to_autoquantizable(model) + out=model(example_input) + model(example_input2) + 
torchao.quantization.change_autoquantizable_to_quantized(model) + out2 = model(example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) if __name__ == "__main__": unittest.main() diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 33678995af..82fb117363 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -52,13 +52,15 @@ def test_int_mm(self, device, dtype): @parameterized.expand( [ ("cuda", torch.bfloat16), - # TODO: ("cpu", torch.bfloat16), + ("cpu", torch.bfloat16), ("cuda", torch.float16), - # TODO: ("cpu", torch.float16), + ("cpu", torch.float16), ] ) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int_scaled_mm(self, device, dtype): + if device == "cuda" and not torch.cuda.is_available(): + self.skipTest(f"{device} not available") + from torchao.kernel import intmm dtype = torch.bfloat16 diff --git a/test/quantization/model.py b/test/quantization/model.py index 17a59e5bb0..e851901c41 100644 --- a/test/quantization/model.py +++ b/test/quantization/model.py @@ -10,12 +10,22 @@ import torch.nn as nn from torch import Tensor from torch.nn import functional as F - - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) +from torchao.quantization.utils import find_multiple + +def prepare_inputs_for_model(inps, max_new_tokens=1): + # this is because input from lm-eval is 2d + if inps.dim() != 2: + raise ValueError(f"Expected input to be of dim 2, but got {inps.dim()}") + + inps = inps.squeeze(0) + # setup inputs in correct format + T = inps.size(0) + T_new = T + max_new_tokens + seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device) + seq[:T] = inps + input_pos = torch.arange(0, T, device=inps.device) + x = seq.index_select(0, input_pos).view(1, -1) + return (x, input_pos) @dataclass class ModelArgs: @@ -76,10 +86,8 @@ def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val + k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val) + v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val) return k_out, v_out diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index cb5b8344ca..d772fda831 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -8,6 +8,7 @@ # This test takes a long time to run import unittest import torch +import os from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantize_pt2e import ( prepare_pt2e, @@ -18,18 +19,20 @@ get_symmetric_quantization_config, ) -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.quantization.quant_api import apply_dynamic_quant from torchao.quantization.quant_api import ( + _replace_with_custom_fn_if_matches_filter, + apply_dynamic_quant, + apply_weight_only_int8_quant, Quantizer, TwoStepQuantizer, ) from torchao.quantization.utils import ( + TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, ) from pathlib import Path from sentencepiece import SentencePieceProcessor -from model import Transformer +from model import Transformer, prepare_inputs_for_model def dynamic_quant(model, example_inputs): @@ -136,12 +139,32 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): compiled = m(*example_inputs) 
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_int8_wo_quant_save_load(self): + m = M().eval().cpu() + apply_weight_only_int8_quant(m) + example_inputs = m.example_inputs() + ref = m(*example_inputs) + _TMP_FN = "_test.pt" + torch.save(m.state_dict(), _TMP_FN) + + state_dict = torch.load(_TMP_FN) + os.remove(_TMP_FN) + m2 = M().eval() + apply_weight_only_int8_quant(m2) + m2.load_state_dict(state_dict) + m2 = m2.to(device="cuda") + example_inputs = map(lambda x: x.cuda(), example_inputs) + res = m2(*example_inputs) + + torch.testing.assert_close(ref, res.cpu()) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") def test_8da4w_quantizer(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer - from torchao.quantization.quant_api import Int8DynActInt4WeightLinear + from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear - quantizer = Int8DynActInt4WeightQuantizer(group_size=32) + quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) m = M().eval() example_inputs = m.example_inputs() m = quantizer.quantize(m) @@ -150,8 +173,8 @@ def test_8da4w_quantizer(self): m(*example_inputs) @unittest.skip("skipping until we get checkpoints for gpt-fast") - def test_gptq_quantizer(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer + def test_8da4w_gptq_quantizer(self): + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -160,6 +183,7 @@ def test_gptq_quantizer(self): checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) model = model.to(dtype=precision, device=device) + model.eval() tokenizer_path = checkpoint_path.parent / "tokenizer.model" assert tokenizer_path.is_file(), tokenizer_path tokenizer = SentencePieceProcessor( # pyre-ignore[28] @@ -169,20 +193,255 @@ def test_gptq_quantizer(self): percdamp = 0.01 groupsize = 128 calibration_tasks = ["wikitext"] - calibration_limit = 5 + calibration_limit = 1 calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - quantizer = Int8DynActInt4WeightGPTQQuantizer( + + inputs = InputRecorder( tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + + quantizer = Int8DynActInt4WeightGPTQQuantizer( blocksize, percdamp, groupsize, + precision=precision, + ) + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + model = quantizer.quantize(model, inputs) + result=TransformerEvalWrapper( + model, + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + + assert result['results']['wikitext']['word_perplexity,none'] < 7.88, ( + f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" + ) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + def test_8da4w_quantizer_eval(self): + from torchao.quantization.quant_api import 
Int8DynActInt4WeightQuantizer + from torchao.quantization.GPTQ import TransformerEvalWrapper + + precision = torch.bfloat16 + device = "cpu" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + + quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision) + q_model = quantizer.quantize(model) + result=TransformerEvalWrapper( + q_model, + tokenizer, + q_model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( + f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" + ) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_gptq_quantizer_gpt_fast(self): + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder + # should be similar to TorchCompileDynamicQuantizer + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + blocksize = 128 + percdamp = 0.01 + groupsize = 128 + calibration_tasks = ["wikitext"] + calibration_limit = 1 + calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model + pad_calibration_inputs = False + + inputs = InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ).record_inputs( calibration_tasks, calibration_limit, + ).get_inputs() + + quantizer = Int8DynActInt4WeightGPTQQuantizer( + blocksize, + percdamp, + groupsize, + _is_gpt_fast=True, + _use_cuda=True, + ) + + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + + model = quantizer.quantize(model, inputs) + compiled = torch.compile(model, mode="max-autotune") + with torch.no_grad(): + compiled(inputs[0].values[0], inputs[1].values[0]) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_gptq_quantizer_int4wo(self): + from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device="cpu") + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = 
SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + blocksize = 128 + percdamp = 0.01 + groupsize = 128 + calibration_tasks = ["wikitext"] + calibration_limit = 1 + calibration_seq_length = 100 + input_prep_func = prepare_inputs_for_model + pad_calibration_inputs = False + + inputs = InputRecorder( + tokenizer, calibration_seq_length, + input_prep_func, pad_calibration_inputs, + model.config.vocab_size, + device="cpu", + ).record_inputs( + calibration_tasks, + calibration_limit, + ).get_inputs() + + quantizer = Int4WeightOnlyGPTQQuantizer( + blocksize, + percdamp, + groupsize, + ) + model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) + + model = quantizer.quantize(model, inputs).cuda() + result = TransformerEvalWrapper( + model.cuda(), + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none'] < 7.77, ( + f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" + ) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_quantizer_int4wo(self): + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + groupsize = 128 + quantizer = Int4WeightOnlyQuantizer( + groupsize, + ) + model = quantizer.quantize(model).cuda() + result = TransformerEvalWrapper( + model, + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( + f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" + ) + + @unittest.skip("skipping until we get checkpoints for gpt-fast") + def test_eval_wrapper(self): + from torchao.quantization.GPTQ import TransformerEvalWrapper + precision = torch.bfloat16 + device = "cuda" + checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") + model = Transformer.from_name(checkpoint_path.parent.name) + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + model.eval() + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), tokenizer_path + tokenizer = SentencePieceProcessor( # pyre-ignore[28] + model_file=str(tokenizer_path) + ) + result=TransformerEvalWrapper( + model, + tokenizer, + model.config.block_size, + prepare_inputs_for_model, + device, + ).run_eval( + ["wikitext"], + 1, + ) + assert result['results']['wikitext']['word_perplexity,none']<7.77, ( + f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" ) - model = quantizer.quantize(model) if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_primitives.py 
b/test/quantization/test_quant_primitives.py new file mode 100644 index 0000000000..fef3f83dcc --- /dev/null +++ b/test/quantization/test_quant_primitives.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# mypy: ignore-errors +# This test takes a long time to run +import unittest +import torch +from torchao.quantization.quant_primitives import get_group_qparams_symmetric +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +class TestQuantPrimitives(unittest.TestCase): + SEED = 123 + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") + def test_get_group_qparams_symmetric(self): + """ + Test that `get_group_qparams_symmetric` produces the exact same scales as + `PerChannelMinMaxObserver._calculate_qparams`. + """ + n_bit = 4 + qmin = -(2 ** (n_bit - 1)) + qmax = 2 ** (n_bit - 1) - 1 + eps = torch.finfo(torch.float32).eps + groupsize = 256 + torch.manual_seed(self.SEED) + weight = torch.randn(100, 256).to(torch.float16) + + # calculate observer scales + obs = torch.ao.quantization.PerChannelMinMaxObserver( + ch_axis=0, + qscheme=torch.per_channel_symmetric, + quant_min=qmin, + quant_max=qmax, + # This is needed to ensure `min_val` and `max_val` are fp16, + # otherwise they default to fp32 and the qparams will be slightly off + factory_kwargs={"dtype": torch.float16} + ) + obs(weight) + (scale_obs, _) = obs.calculate_qparams() + scale_obs = scale_obs.reshape(weight.shape[0], -1) + + # assert that scales are identical + (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize) + torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0) + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/__init__.py b/torchao/__init__.py index b0dc9b1e1e..ecd2ccf4b9 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,8 +1,13 @@ +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + autoquant, +) from . import dtypes -from .quantization.quant_api import apply_dynamic_quant -from .quantization.quant_api import apply_weight_only_int8_quant __all__ = [ - "dtypes", - "apply_dynamic_quant", + "dtypes", + "apply_dynamic_quant", + "apply_weight_only_int8_quant", + "autoquant", ] diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 34ad778027..ea45a6c0de 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -569,9 +569,9 @@ def forward(ctx, input: torch.Tensor, weight: NF4Tensor): # inconsistently. 
def backward(ctx, grad_output): - """The nf4 weight will never require grad so we can just return the grad_output @ weight.get_original_weight()""" + """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)""" weight: NF4Tensor = ctx.nf4_weight - return grad_output @ weight.get_original_weight(), None + return grad_output @ weight.to(grad_output.dtype), None def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index bdd583520e..d2afa66a0a 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -2,52 +2,65 @@ import os import torch -from torch._dynamo import is_compiling as dynamo_is_compiling -from torch._higher_order_ops.out_dtype import out_dtype +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2 try: - from torchao.kernel import intmm_triton + # Only works for torch2.2 or newer. + if TORCH_VERSION_AFTER_2_2: + from torchao.kernel import intmm_triton + else: + intmm_triton = None except ImportError: + # On cpu-only builds might not be available. intmm_triton = None AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0))) -def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: - # torch.compile path - if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - - # error checking for cublas path - assert ( - mat2.device == input.device - ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" - device_cpu = "cpu" in [mat2.device.type, input.device.type] - # with input.shape = [i,j] and mat2.shape = [j,k] - i_is_strictly_greater_than_16 = input.shape[0] > 16 - j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) - k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) - bad_dimensions_for_cublas = not ( - i_is_strictly_greater_than_16 - and j_is_nonzero_multiple_of_8 - and k_is_nonzero_multiple_of_8 - ) - - if device_cpu or bad_dimensions_for_cublas: - # fallback path - return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( - input.device.type +# torch._int_mm doesn't exist before 2.2 +if TORCH_VERSION_AFTER_2_2: + from torch._dynamo import is_compiling as dynamo_is_compiling + from torch._higher_order_ops.out_dtype import out_dtype + def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + # error checking for cublas path + assert ( + mat2.device == input.device + ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + i_is_strictly_greater_than_16 = input.shape[0] > 16 + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + i_is_strictly_greater_than_16 + and j_is_nonzero_multiple_of_8 + and k_is_nonzero_multiple_of_8 ) - - # cublas paths - if not mat2.is_contiguous(): # silently gives incorrect result without this - mat2 = mat2.contiguous() - if (not input.is_contiguous()) and ( - input.shape[0] % 8 != 0 - ): # gives cryptic error without this - 
input = ( - input.contiguous() - ) # (it seems the transpose makes cublas check the above j constraint on i) - return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) + + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) +else: + def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + # We can improve on this by writing Triton code that works for older versions of Triton + # that ship with 2.1 or 2.0. + return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) def int_matmul(a, b): diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 4e84d9cd3c..d10dac0abe 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -356,3 +356,9 @@ def int_scaled_matmul_cuda(a, b, scales1): int_scaled_matmul_kernel, [a, b, scales1, c], int8_mm_kernel_configs ) return int_scaled_matmul_kernel(a, b, scales1, c, best_config) + + +@torch.library.impl(lib, "int_scaled_matmul", "CPU") +def int_scaled_matmul_cpu(a, b, scales1): + c = torch._int_mm(a, b) + return c.to(scales1.dtype) * scales1 diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index a82edca528..91e604c8cf 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -9,7 +9,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional +from typing import Optional, List import torch @@ -17,27 +17,21 @@ import torch.nn as nn import torch.nn.functional as F -# from model import Transformer # pyre-ignore[21] from torch.utils._pytree import tree_flatten, tree_unflatten - +from .utils import TORCH_VERSION_AFTER_2_3, find_multiple +from typing import Any, Dict, Optional +from .unified import Quantizer + +from .quant_primitives import ( + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor_from_qparams, + groupwise_affine_dequantize_tensor_from_qparams, + pack_tinygemm_scales_and_zeros, + groupwise_affine_quantize_tensor, +) aten = torch.ops.aten -## generate.py ## - - -def encode_tokens(tokenizer, string, bos=True, device="cuda"): - - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - - -def model_forward(model, x, input_pos): - return model(x, input_pos) - - ## eval.py ## try: @@ -61,72 +55,73 @@ def model_forward(model, x, input_pos): else: logging.info("lm_eval is not installed, GPTQ may not be usable") +add_ons = [] -def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - model: torch.nn.Module, - prompt: torch.Tensor, - max_new_tokens: int, - max_seq_length: Optional[int] = None, - block_size: int = 2048, -): - """ - Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length - that are needed for prefill or model_forward - - Args: - model (torch.nn.Module): The model whose cache gets set up - prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence. 
- max_new_tokens (int): The desired maximum number of new tokens that can be generated. - max_seq_length (Optional[int], optional): The maximum sequence length allowed. - - Returns: - seq (torch.Tensor): prompt but padded with zeros to size max_seq_length - input_pos (torch.Tensor): tensor of integers in increasing order - max_seq_length (int): The maximum sequence length allowed, updated based on other numbers - """ - T = prompt.size(0) - T_new = T + max_new_tokens - if max_seq_length is None: - max_seq_length = min(T_new, block_size) - - device, dtype = prompt.device, prompt.dtype - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) +if lm_eval_available: + add_ons += ["InputRecorder", "TransformerEvalWrapper"] - # no caches in executorch llama2 7b model? - # with torch.device(device): - # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) +if TORCH_VERSION_AFTER_2_3: + add_ons += ["Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer"] - return seq, input_pos, max_seq_length +__all__ = [ + "MultiInput", + "Int4WeightOnlyGPTQQuantizer", + "Int4WeightOnlyQuantizer", +] + add_ons if lm_eval_available: - - class GPTFastEvalWrapper(eval_wrapper): # pyre-ignore[11] + class InputRecorder(eval_wrapper): """ - A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. + This is a fake evaluation wrapper from the lm_eval library that just records the inputs + so that they can be used in calibration. + + If pad_calibration_inputs is enabled, the input recorder will take + each input and pad/truncate it down to the calibration_seq_length. + (if using padding you should set the embeddings for the pad_token to 0 + in the model) + + Note: after padding/truncation, input_prep_function is called to bring + it to the proper form to be inserted into a given model. + + If not, it will only truncate inputs to the desired length. 
""" def __init__( self, - model: torch.nn.Module, tokenizer, - max_seq_length: Optional[int] = None, + calibration_seq_length, + input_prep_func=None, + pad_calibration_inputs=False, + vocab_size=32000, + pad_token=0, + device="cpu", ): super().__init__() - self._model = model - self._tokenizer = tokenizer - self._device = torch.device("cuda") + self._device = torch.device(device) + self.vocab_size = vocab_size + self._max_seq_length = calibration_seq_length + self.calibration_seq_length = calibration_seq_length - self._max_seq_length = 2048 if max_seq_length is None else max_seq_length + # need to take inps and convert to corrent input + # for model + self.input_prep_func = ( + input_prep_func if input_prep_func is not None + else lambda x: (x,) + ) + + self.pad_calibration_inputs = pad_calibration_inputs + self.pad_token = pad_token + + self.inputs = None @property def eot_token_id(self): - return self._tokenizer.eos_id() + try: + return self._tokenizer.eos_id() + except: + return self._tokenizer.eos_id @property def max_length(self): @@ -145,89 +140,19 @@ def device(self): return self._device def tok_encode(self, string: str, **kwargs): - encoded = encode_tokens( - self._tokenizer, string, bos=True, device=self._device - ) - # encoded is a pytorch tensor, but some internal logic in the - # eval harness expects it to be a list instead # TODO: verify this for multi-batch as well - encoded = encoded.tolist() - return encoded + tokens = self._tokenizer.encode(string) + if hasattr(self._tokenizer, "bos_id"): + try: + tokens = [self._tokenizer.bos_id()] + tokens + except: + tokens = [self._tokenizer.bos_id] + tokens + return tokens def tok_decode(self, tokens): decoded = self._tokenizer.decode(tokens) return decoded - def _model_call(self, inps): - print("in model_call") - # TODO: make batches work - inps = inps.squeeze(0) - - max_new_tokens = 1 - seq, input_pos, max_seq_length = ( - setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - self._model, - inps, - max_new_tokens, - self.max_length, - ) - ) - x = seq.index_select(0, input_pos).view(1, -1) - logits = model_forward(self._model, x, input_pos) - return logits - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception("unimplemented") - - class InputRecorder(GPTFastEvalWrapper): - """ - This is a fake evaluation wrapper that just records the inputs - so that they can be used in calibration. - - If pad_calibration_inputs is enabled, the input recorder will take - each input and pad/truncate it down to the calibration_seq_length. - It will also edit the model embeddings to be zero for the 0 token used - in padding and avoid any inputs with the 0 token. - - If not, it will only truncate inputs to the desired length. - """ - - def __init__( - self, - model: torch.nn.Module, - tokenizer, - calibration_seq_length, - pad_calibration_inputs=False, - ): - super().__init__(model, tokenizer, calibration_seq_length) - self._model = model - - self._tokenizer = tokenizer - self._device = torch.device("cpu") - - self.vocab_size = model.vocab_size - - self.calibration_seq_length = calibration_seq_length - - self.pad_calibration_inputs = pad_calibration_inputs - - self.inputs = None - - if self.pad_calibration_inputs: - # This is needed for the pad_calibration_inputs option - # to work properly, the 0 token's embeddings are set to 0 so that - # the padded inputs will not affect the model numerics. 
This token isn't used - # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs - # where it appears - try: - if isinstance(self._model.transformer.wte, nn.Embedding): - self._model.transformer.wte.weight.data[0, :] *= 0 - except: - print( - "Did not find embeddings in model.transformer.wte, disabling padding" - ) - self.pad_calibration_inputs = False - def add_input(self, args): if self.inputs is None: self.inputs = [MultiInput([arg]) for arg in args] @@ -236,7 +161,27 @@ def add_input(self, args): multi.add_input(arg) for (multi, arg) in zip(self.inputs, args) ] - def get_recorded_inputs(self): + def record_inputs( + self, + calibration_tasks, + calibration_limit, + ): + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + + evaluate( + self, + task_dict, + limit=calibration_limit, + ) + return self + + def get_inputs(self): return self.inputs def _model_call(self, inps): @@ -247,7 +192,7 @@ def _model_call(self, inps): (T < self.calibration_seq_length and not self.pad_calibration_inputs) or # can't use inputs that actually use token we use for padding - (self.pad_calibration_inputs and 0 in inps) + (self.pad_calibration_inputs and self.pad_token in inps) ): # give random output return torch.randn( @@ -258,24 +203,76 @@ def _model_call(self, inps): if T >= self.calibration_seq_length: inps = inps[: self.calibration_seq_length] else: - inps = F.pad(inps, (0, self.calibration_seq_length - T)) - - max_new_tokens = 1 - ( - seq, - input_pos, - max_seq_length, - ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - self._model, inps, max_new_tokens, self.max_length - ) - x = seq.index_select(0, input_pos).view(1, -1) - self.add_input((x, input_pos)) + inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T)) + + inps = inps.unsqueeze(0) + model_in = self.input_prep_func(inps) + + self.add_input(model_in) # output `something` with correct shape to keep eval going return torch.randn( (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device ) + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + + class TransformerEvalWrapper(InputRecorder): + """ + A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. 
+ """ + def __init__( + self, + model, + tokenizer, + max_seq_length, + input_prep_func=None, + device="cuda" + ): + super().__init__(None, None) + self._model = model + self._tokenizer = tokenizer + self._device = torch.device(device) + self._max_seq_length = max_seq_length + + # need to take inps and convert to corrent input + # for model + self.input_prep_func = ( + input_prep_func if input_prep_func is not None + else lambda x: (x,) + ) + + def _model_call(self, inps): + # TODO: make batches work + input = self.input_prep_func(inps) + + max_seq_length = min(inps.size(1), self.max_length) + with torch.device(self._device): + self._model.setup_caches(self.batch_size, max_seq_length) + logits = self._model(*input) + return logits + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception('unimplemented') + + def run_eval(self, tasks, limit): + try: + lm_eval.tasks.initialize_tasks() + except: + pass + + task_dict = get_task_dict(tasks) + print("Evaluating Model On: ", task_dict) + with torch.no_grad(): + result = evaluate( + self, + task_dict, + limit=limit, + ) + for task, res in result["results"].items(): + print(f"{task}: {res}") + return result class MultiInput: @@ -305,8 +302,7 @@ class GenericGPTQRunner(fx.Interpreter): into the state_dict so that the quantized model weights/qparams can be loaded directly into the model. - This class is expected to work in concert with a GPTQSimpleQuantizer - class to define the specific type of quantization being done. + intended to be used in concert with a GPTQQuantizer class to define the quantization mode. """ def __init__( @@ -323,9 +319,9 @@ def __init__( } # trace model for one input - one_input = [multi.values[0] for multi in inputs] # pyre-ignore[16] + one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16] exported_model = torch._dynamo.export( - model, aten_graph=True, pre_dispatch=True, tracing_mode="fake" + model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake" )(*one_input) super().__init__(exported_model.graph_module) @@ -348,7 +344,7 @@ def configure_quantization_mode( combine_qparams_list_func, make_names_and_values_dict_func, skip_layer_func, - dyn_quant_func = None, + act_fake_quant_func = None, ): # these functions need to already be curried with all inputs other than weight, qparams @@ -373,7 +369,11 @@ def configure_quantization_mode( self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict # note any final packing for storage should happen here - self.dyn_quant_func = dyn_quant_func + # `act_fake_quant_func` + if act_fake_quant_func is None: + self.act_fake_quant_func = lambda x: x + else: + self.act_fake_quant_func = act_fake_quant_func # accepts [activation tensor], returns a fake-quantized activation tensor return self def run(self): @@ -397,7 +397,7 @@ def get_quantized_state_dict(self): quantized_state_dict.pop(param_fqn) return quantized_state_dict - def call_function(self, target, args, kwargs, skip_quant=False): # noqa: C901 + def call_function(self, target, args, kwargs, already_quantized=False): # noqa: C901 def tensors_to_cuda(args): new_args = [] @@ -437,9 +437,11 @@ def tensors_to_cuda(args): quantize_linear = ( (target == aten.linear.default) # if its a linear and id(args[1]) in self.id_to_name # and if we know the layer name - and not skip_quant # and if we weren't told to skip quantization + # and we haven't already quantized this layer + and 
not already_quantized # and if the skip_layer_func doesn't say we should skip and not (self.skip_layer_func is not None and self.skip_layer_func(args[1])) + ) # then we will quantize this linear layer/weight if quantize_linear: # instantiate variables for GPTQ @@ -454,8 +456,7 @@ def tensors_to_cuda(args): quantize_linear ): # calculate H instead of output (will run the linear eventually with updated weight) x = cur_args[0].float() - if self.dyn_quant_func is not None: - x = self.dyn_quant_func(x) + x = self.act_fake_quant_func(x) shape = x.shape n = 1 if len(shape) == 2 else shape[0] H *= total_batches / (total_batches + n) @@ -465,9 +466,13 @@ def tensors_to_cuda(args): ).t().float() H += x.matmul(x.t()) else: + # weight has already been quantized but still need to apply + # activation quant for final calculation + if already_quantized: + cur_args = (self.act_fake_quant_func(cur_args[0]), *cur_args[1:]) + # get output if its not a linear out = super().call_function(target, cur_args, cur_kwargs) - if isinstance(out, torch.Tensor): outputs.append(out.cpu()) else: @@ -494,12 +499,12 @@ def tensors_to_cuda(args): # run linear with new weight to get corrected output new_out = self.call_function( - target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True + target, (args[0], DQ, *args[2:]), kwargs, already_quantized=True ) if self.debug: old_out = self.call_function( - target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True + target, (args[0][:2], args[1], *args[2:]), kwargs, already_quantized=True ) def SQNR(x, y): @@ -513,7 +518,6 @@ def SQNR(x, y): print( "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after) ) # matches - print( "SQNR for weight (can be low)", SQNR(W, DQ.cuda()) ) # fine to not match @@ -533,7 +537,7 @@ def SQNR(x, y): Q2 = self.quantize_func(W, qparams2) DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype) old_q_out = self.call_function( - target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True + target, (args[0][:2], DQ2, *args[2:]), kwargs, already_quantized=True ) print( @@ -628,3 +632,741 @@ def faster_quant(self, H, W): all_qparams = self.combine_qparams_list_func(all_qparams) Q = self.quantize_func(DQ, all_qparams) return Q, DQ.to(orig_dtype), all_qparams + + +class GPTQQuantizer(Quantizer): + """ + This class implements a GPTQ Quantizer that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base Quantizer class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. 
+ qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + act_fake_quant_func (optional): + A function that (dynamically) quantizes activation to input + Args: + input: input Tensor in f32/bf16/f16 + Returns: + output: dynamically quantized and dequantized Tensor (with the same dtype as input) + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. + """ + + def __init__(self): + + assert self.get_qparams_func is not None + + assert self.quantize_func is not None + + assert self.dequantize_func is not None + + assert self.combine_qparams_list_func is not None + + # `make_names_and_values_dict_func`. + assert self.make_names_and_values_dict_func is not None + + @torch.no_grad() + def _create_quantized_state_dict( + self, + model, + inputs, + blocksize, + percdamp, + groupsize, + # `typing.Dict[, ]` to avoid runtime subscripting errors. 
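        # A rough gloss on the arguments above (following the standard GPTQ recipe that
        # GenericGPTQRunner.faster_quant implements): `blocksize` is the number of weight
        # columns updated per block, `percdamp` the fraction of the mean Hessian diagonal
        # added as damping, and `groupsize` the width of each quantization group.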
+ ) -> Dict: + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + model, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, # pyre-ignore[16] + self.quantize_func, # pyre-ignore[16] + self.dequantize_func, # pyre-ignore[16] + self.combine_qparams_list_func, # pyre-ignore[16] + self.make_names_and_values_dict_func, # pyre-ignore[16] + self.skip_layer_func, # pyre-ignore[16] + self.act_fake_quant_func if hasattr(self, "act_fake_quant_func") else None, # pyre-ignore[16] + ) + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() + + def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": + raise NotImplementedError("_convert_for_runtime not implemented") + + @torch.no_grad() + def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: + pass + +def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): + k_divisible_by_groupsize = k % groupsize == 0 + if inner_k_tiles is not None: + k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 + return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles + return k_divisible_by_groupsize + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x.to(torch.bfloat16), + weight_int4pack, + groupsize, + scales_and_zeros.to(torch.bfloat16) + ).to(dtype=x.dtype) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, in_features: int, out_features: int, + bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, + ) -> None: + super().__init__() + self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) + if self.padding: + from model import find_multiple + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) + self.register_buffer( + "scales_and_zeros", + torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.padding: + import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, + self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None): + + for name, child in module.named_children(): + if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + setattr(module, name, WeightOnlyInt4Linear( + 
child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, + )) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func) + +class Int4WeightOnlyQuantizer(Quantizer): + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = True, + inner_k_tiles: Optional[int] = 8, + device: torch.device = torch.device("cuda"), + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + + self.inner_k_tiles = inner_k_tiles + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.device: torch.device = device + + @torch.no_grad() + def _create_quantized_state_dict( + self, model: torch.nn.Module + ) -> Dict[str, torch.Tensor]: + cur_state_dict = model.state_dict() + for fqn, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + # assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + assert ( + in_features % self.groupsize == 0 + ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0" + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding_allowed: + from .utils import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + ( + w_int4x8, + scales_and_zeros + ) = groupwise_affine_quantize_tensor( + weight, + 4, # n_bit + self.groupsize, + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device) + return cur_state_dict + + def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + ) + return model + + def quantize( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict(model) + model = self._convert_for_runtime(model) + # TODO: make it strict + model.load_state_dict(state_dict, strict=False) + return model + +class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer): + def __init__( + self, + blocksize, + percdamp, + groupsize, + inner_k_tiles=8, + padding_allowed=True, + device: torch.device = torch.device("cuda"), + ): + self.blocksize = blocksize + self.percdamp = percdamp + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding_allowed = padding_allowed + self.device = device + self.act_fake_quant_func = None + n_bit = 4 + self.get_qparams_func = lambda w: get_groupwise_affine_qparams( + w, n_bit, groupsize + ) + self.quantize_func = lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams( + w, qparams[0], qparams[1], n_bit, groupsize + ) + self.dequantize_func = lambda q, qparams: groupwise_affine_dequantize_tensor_from_qparams( + q, + qparams[0], + qparams[1], + 
n_bit, + groupsize, + ) + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] + # skip unless padding_allowed=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed + ) + + # we need to do the padding here, both for q and the qparams if necessary + + # TODO: this is the gpt-fast version, merge with the main version later + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + if not _check_linear_int4_k(k, groupsize): + new_k = find_multiple(k, 1024) + else: + new_k = k + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + q = q.to(torch.int32).to(self.device) + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales = qparams[0].to(torch.bfloat16).to(self.device) + zeros = qparams[1].to(torch.bfloat16).to(self.device) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + def _convert_for_runtime(self, model): + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + skip_layer_func=self.skip_layer_func, + ) + return model + + def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict( + model, + inputs, + self.blocksize, + self.percdamp, + self.groupsize, + ) + model = self._convert_for_runtime(model) + model.load_state_dict(state_dict, strict=False) + return model + + +if TORCH_VERSION_AFTER_2_3: + from .quant_primitives import ( + get_group_qparams_symmetric, + group_quantize_tensor_symmetric, + per_token_dynamic_quant, + ) + + def linear_forward_8da4w( + x, + weight_int8, + scales, + zeros, + out_features, + groupsize, + precision, + ): + x = per_token_dynamic_quant(x) + # TODO: verify and remove following reshape code + # origin_x_size = x.size() + # x = x.reshape(-1, origin_x_size[-1]) + + # TODO: better API + # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed) + n_bit = 4 + quant_min = -(2 ** (n_bit - 1)) + quant_max = 2 ** (n_bit - 1) - 1 + w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group( + weight_int8, + scales, + zeros, + quant_min, + quant_max, + torch.int8, + groupsize, + precision, + ) + + # x = x.to(torch.float16) + # w_dq = w_dq.to(torch.float16) + c = torch.nn.functional.linear(x, w_dq) + + # new_shape = origin_x_size[:-1] + (out_features,) + # c = c.reshape(new_shape) + + return c + + class Int8DynActInt4WeightLinear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + + in_features: int + out_features: int + weight: torch.Tensor + + """ + This module implements a dynamic quantized linear layer with int4 weight. + Weights are per channel groupwise quantized. Parameters of importance + groupsize: the number of elements in each quantized group + precision: precision of input and output. e.g. torch.float32 means input + activation is float32 and output is float32. + scales_precision: precision of per group scale. 
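A small construction sketch (sizes are illustrative; in practice the int8 weight,
scales and zeros buffers are filled by loading a state_dict produced by a quantizer
such as Int8DynActInt4WeightQuantizer below):

    lin = Int8DynActInt4WeightLinear(
        in_features=4096, out_features=4096, bias=False, groupsize=256,
    )
    y = lin(x)  # x is quantized per token on the fly; the int4 weight groups are dequantized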
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 256, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + ) -> None: + super().__init__() + # always pad if needed since it becomes a noop at runtime if not needed + # self.origin_in_features = in_features + assert ( + in_features % groupsize == 0 + ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" + # in_features = _calc_padded_size_linear_int4( + # in_features, groupsize + # ) + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + # TODO: align groupsize naming + self.groupsize = groupsize + # Precision of the activation which also indicates + # output precision of the dynamically quantized linear layer + # that his module represents. + self.precision = precision + + # currently storing unpacked int8 weights + self.register_buffer( + "weight", + torch.empty((out_features, in_features), dtype=torch.int8), + ) + self.register_buffer( + "scales", + torch.empty( + (out_features, in_features // groupsize), + dtype=scales_precision, + ), + ) + self.register_buffer( + "zeros", + torch.empty( + (out_features, in_features // groupsize), + dtype=scales_precision, + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(self.precision) + # padding is removed for perf + # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_8da4w( + input, + self.weight, + self.scales, + self.zeros, + self.out_features, + self.groupsize, + self.precision, + ) + + + def replace_linear_8da4w( + module, + groupsize, + padding_allowed, + precision, + scales_precision, + ): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed: + setattr( + module, + name, + Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + precision=precision, + scales_precision=scales_precision, + ), + ) + else: + replace_linear_8da4w( + child, + groupsize, + padding_allowed, + precision, + scales_precision, + ) + + class Int8DynActInt4WeightQuantizer(Quantizer): + def __init__( + self, + groupsize: int = 256, + padding_allowed: bool = False, + precision: torch.dtype = torch.float32, + scales_precision: torch.dtype = torch.float32, + inner_k_tiles: Optional[int] = None, + _is_gpt_fast: bool = False, + ) -> None: + super().__init__() + if _is_gpt_fast: + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + else: + assert inner_k_tiles is None + self._is_gpt_fast = _is_gpt_fast + self.inner_k_tiles = inner_k_tiles + self.groupsize: int = groupsize + self.padding_allowed: bool = padding_allowed + self.precision: torch.dtype = precision + self.scales_precision: torch.dtype = scales_precision + + @torch.no_grad() + def _create_quantized_state_dict( + self, model: torch.nn.Module + ) -> Dict[str, torch.Tensor]: + cur_state_dict = model.state_dict() + for fqn, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + # assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + assert ( + in_features % self.groupsize == 0 + ), f"require in_features:{in_features} % 
self.groupsize:{self.groupsize} == 0" + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding_allowed: + from .utils import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + ( + weight_int8, + scales, + zeros, + ) = group_quantize_tensor_symmetric( + weight.to(self.precision), + 4, # n_bit + self.groupsize, + self.scales_precision, + ) + if self._is_gpt_fast: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int8.to(torch.int32), self.inner_k_tiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") + else: + cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu") + cur_state_dict[f"{fqn}.scales"] = scales.to("cpu") + cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu") + # TODO: support bias? + + return cur_state_dict + + def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: + if self._is_gpt_fast: + # TODO: temporary path for gpt-fast, will remove later + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + ) + else: + replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.precision, + ) + return model + + def quantize( + self, model: torch.nn.Module, *args: Any, **kwargs: Any + ) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict(model) + model = self._convert_for_runtime(model) + # TODO: make it strict + model.load_state_dict(state_dict, strict=False) + return model + + + class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer): + def __init__( + self, + blocksize, + percdamp, + groupsize, + inner_k_tiles=8, + padding_allowed=True, + precision=torch.float32, + _is_gpt_fast=False, + ): + self._is_gpt_fast = _is_gpt_fast + self.blocksize = blocksize + self.percdamp = percdamp + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding_allowed = padding_allowed + self.precision = precision + + self.act_fake_quant_func = per_token_dynamic_quant + n_bit = 4 + self.get_qparams_func = lambda w: get_group_qparams_symmetric( + w, n_bit, groupsize, self.precision + ) + quant_min = -(2 ** (n_bit - 1)) + quant_max = 2 ** (n_bit - 1) - 1 + + self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group( + w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize + ) + + self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group( + q, + qparams[0], + qparams[1], + quant_min, + quant_max, + torch.int8, + groupsize, + self.precision, + ) + + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] + # skip unless padding_allowed=True or its correctly sized + + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed + ) + + # we need to do the padding here, both for q and the qparams if necessary + + # TODO: this is the 
gpt-fast version, merge with the main version later + def make_names_and_values_dict_func_gpt_fast(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + q = q.to(torch.int32) + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales = qparams[0].to(torch.bfloat16) + zeros = qparams[1].to(torch.bfloat16) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1 if groupsize is None else groupsize) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = F.pad(q, pad=(0, delta_k)) + scales = qparams[0].to(self.precision) + zeros = qparams[1].to(self.precision) + return {"weight": final_q, "scales": scales, "zeros": zeros} + + self.make_names_and_values_dict_func = make_names_and_values_dict_func_gpt_fast if self._is_gpt_fast else make_names_and_values_dict_func + super().__init__() + + def _convert_for_runtime(self, model): + if self._is_gpt_fast: + # TODO: temporary path for gpt-fast, will remove later + replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + ) + else: + replace_linear_8da4w( + model, + self.groupsize, + self.padding_allowed, + self.precision, + self.precision, + ) + return model + + def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module: + state_dict = self._create_quantized_state_dict( + model, + inputs, + self.blocksize, + self.percdamp, + self.groupsize, + ) + model = self._convert_for_runtime(model) + model.load_state_dict(state_dict, strict=False) + return model diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 80599cb71c..ab51dbb3a5 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -10,6 +10,8 @@ from .quant_primitives import * # noqa: F403 from .utils import * # noqa: F403 from .weight_only import * # noqa: F403 +from .unified import * +from .autoquant import * __all__ = [ "DynamicallyPerAxisQuantizedLinear", @@ -25,6 +27,9 @@ "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", + "autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", "quant_int8_dynamic_linear", "quant_int8_matmul", "quant_int8_dynamic_per_token_linear", @@ -41,4 +46,6 @@ "compute_error", "get_model_size_in_bytes", "WeightOnlyInt8QuantLinear", + "Int4WeightOnlyGPTQQuantizer", + "Int4WeightOnlyQuantizer", ] diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py new file mode 100644 index 0000000000..f1f387d7b5 --- /dev/null +++ b/torchao/quantization/autoquant.py @@ -0,0 +1,390 @@ +import torch +from .subclass import ( # noqa + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + QuantizedLinearWeightBase, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from .quant_primitives import ( + quantize_activation_per_token_absmax, + safe_int_mm, +) +import torch.nn.functional as F +from torch._inductor.utils import do_bench +aten = torch.ops.aten + +AUTOQUANT_CACHE = {} + +def 
check_cache(cls, shapes_and_dtype): + return AUTOQUANT_CACHE.get((cls,)+shapes_and_dtype, None) + +def update_cache(cls, shapes_and_dtype, res): + AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res + +class AutoQuantizableLinearWeight(torch.Tensor): + """ + when run, finds best type of quantization for this tensor and swaps itself with that + """ + @staticmethod + def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = ( + kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype + ) + kwargs["requires_grad"] = False + shape = kwargs.pop("shape", weight.shape) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + self.weight = weight + self.qtensor_class_list = qtensor_class_list + self.logged_data = {} + self.mode = mode + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})" + ) + + @staticmethod + def log_shape(act_mat, w_autoquant, bias): + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + logged_dtype = act_mat.dtype + logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,) + shapes_and_dtype = logged_shapes + (logged_dtype,) + w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0) + for q_cls in w_autoquant.qtensor_class_list: + if check_cache(q_cls, shapes_and_dtype) is None: + update_cache(q_cls, shapes_and_dtype, None) + + def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): + act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype + if check_cache(q_cls, shapes_and_dtype) is None: + with torch.no_grad(): + act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) + bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device) + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + update_cache(q_cls, shapes_and_dtype, res) + + def to_quantized(self, error_on_unseen, **kwargs): + if error_on_unseen and self.logged_data == {}: + raise RuntimeError("must run module normally to get shape, dtype info for autoquant") + elif (self.logged_data == {}) and not error_on_unseen: + # default back to non-quantized weight if not seen + self = AQFloatLinearWeight.from_float(self.weight) + return self + + + # only want to do shape+final print a single time if multiple layers + # see/have same shapes so we gate on check_cache being empty for + # at least one of the class/shape combinations. 
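        # In outline, the selection below: for every candidate class in
        # qtensor_class_list, sum the benchmarked runtime over each logged
        # (activation shape, weight shape, bias shape, dtype) combination,
        # weighted by how often that shape was seen, keep the class with the
        # lowest total, and finally rebuild this weight via that class's from_float.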
+ do_final_print = False + print_once = True + + def count_shapes(self, do_print=True): + differe_shape_count=0 + for shapes_and_dtype, times_seen in self.logged_data.items(): + differe_shape_count += 1 + if do_print: + act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype + print(f"activation_shapes: {act_shape}, times_seen: {times_seen}") + if do_print: + print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}") + return differe_shape_count + + # check each class + best_time = torch.inf + best_cls = None + for q_cls in self.qtensor_class_list: + # for each logged shape+dtype, benchmark + cur_time=0 + shape_count = count_shapes(self, do_print=False) + for shapes_and_dtype, times_seen in self.logged_data.items(): + if check_cache(q_cls, shapes_and_dtype) is None: + # only do final print if we have to autotune at least one cls/shape pair + do_final_print=True + + # only print shapes once + if print_once == True: + print_once = False + count_shapes(self, do_print=True) + + time_for_best_shape = check_cache(best_cls, shapes_and_dtype) + time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape + self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape) + torch._dynamo.reset() + cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen + if shape_count is not None and shape_count > 1: + print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") + if best_time >= cur_time: + best_time = cur_time + best_cls = q_cls + # only print if this is the first time seeing some cls+shape combo, + # otherwise we will print the same thing for every layer. + if do_final_print: + print(f"best_cls={best_cls}\n") + # TODO handle random cls args/kwargs? or should they be curried? 
+ self = best_cls.from_float(self.weight) + return self + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode + ) + + def __tensor_flatten__(self): + return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + weight = tensor_data_dict["weight"] + qtensor_class_list, mode, dtype, shape = tensor_attributes[0] + return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + + @classmethod + def from_float(cls, weight, qtensor_class_list, **kwargs): + return cls(weight, qtensor_class_list, **kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + mat1, w_autoquant, bias = ( + args[0], + args[1], + args[2] if len(args)>2 else None + ) + cls.log_shape(mat1, w_autoquant, bias) + return func(mat1, w_autoquant.weight, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func is aten.detach.default: + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + +def do_autoquant_bench(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ + rep = kwargs.pop("rep", 100) + warmup = kwargs.pop("warmup", 25) + with torch.no_grad(): + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + op(*args, **kwargs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + op(*args, **kwargs) + res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + return res + +def _is_interpolate_mode(mode): + if isinstance(mode, list) and mode[0]=="interpolate" and len(mode)==2 and isinstance(mode[1], float): + return True + return False + +class AQMixin(): + """ + Mixin to turn normal quantized subclasses into autoquantizable ones + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + w_qtensor = cls.from_float(weight) + if _is_interpolate_mode(mode): + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs") + else: + func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c)) + q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100) + if res < best_time*1.1: + res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900) + res=(res2*.9+res*.1) + print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") + return res + +class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): + """ + AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + if not _is_interpolate_mode(mode): + return super()._autoquant_test(act_mat, weight, bias, best_time, mode) + + # SAM best is between 
.8 and 1, SDXL also performs best in this range + INTERPOLATION_CONSTANT = mode[1] + w_qtensor = cls.from_float(weight) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_int8, x_scales, w_vals_int8: + safe_int_mm(x_vals_int8, w_vals_int8) * x_scales + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") + with torch.no_grad(): + res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) + print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") + + # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op + if res_matmul>=best_time: + return res_matmul + + # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT + to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) + res = super()._autoquant_test(act_mat, weight, bias, to_beat) + max_int_const_win = (best_time-res_matmul)/(res-res_matmul) + res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul + print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") + return res_f + +class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight + """ + +class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + orig_shape = act_mat.shape + act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) + y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales + if bias is not None: + y += bias + return y.to(orig_dtype) + + @classmethod + def _autoquant_test(cls, act_mat, *args): + # if act_mat has batchsize>2 don't use this kernel + if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32: + return torch.inf + return super()._autoquant_test(act_mat, *args) + +class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + def _quantized_op(act_mat, w_qtensor, bias): + orig_shape = act_mat.shape + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales) + y=y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y + +class AQFloatLinearWeight(torch.Tensor, AQMixin): + """ + A class to be used in concert with AutoQuantizableLinearWeight to provide a + default/non-quantized option. Only implements the bare minimum needed to work with the + AutoQuantizableLinearWeight class using the same interfaces that would normally be + used by QTensor subclasses but for a default linear op instead. Result of from_float + is not a tensor subclass, but rather the float tensor. 
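A concrete consequence of that contract (sketch):

    w = AQFloatLinearWeight.from_float(weight)   # returns `weight` unchanged
    # so a layer that "quantizes" to this class keeps the plain F.linear path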
+ """ + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor, bias) + + @classmethod + def from_float(cls, weight): + return weight + +DEFAULT_CLASS_LIST = [ + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + # AQWeightOnlyQuantizedLinearWeight3, + # TODO this gets picked in places where it makes perf worse, why? +] + +def change_linears_to_autoquantizable(model, **kwargs): + """ + Converts all linear weight tensors to the + AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed + by running the model and then calling change_autoquantizable_to_quantized + """ + from torchao.quantization.quant_api import _is_linear + filter_fn = kwargs.pop("filter_fn", _is_linear) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["mode"] = kwargs.get("mode", ["relu", None]) + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + from torchao.quantization.quant_api import _get_subclass_inserter + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs), + filter_fn if filter_fn is not None else _is_linear, + ) + +def change_autoquantizable_to_quantized(model, **kwargs): + """ + Converts AutoQuantizableLinearWeight tensor subclasses + to various quantized/non-quantized tensor subclasses depending + on benchmark results. Expectation is that these modules are + torch.compiled afterwards. + """ + hold = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False + + filter_fn = kwargs.pop( + "filter_fn", + lambda mod, *args: + hasattr(mod, "weight") and isinstance(mod.weight, AutoQuantizableLinearWeight) + ) + error_on_unseen=kwargs.pop("error_on_unseen", True) + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + from torchao.quantization.quant_api import _get_subclass_inserter + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter( + AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs + ), + filter_fn, + ) + torch._dynamo.config.automatic_dynamic_shapes = hold + torch._dynamo.reset() + +@torch.no_grad() +def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs): + """ + Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape + across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer + and applies that type of quantization. 
+ """ + if filter_fn is None: + from torchao.quantization.quant_api import _is_linear + filter_fn = _is_linear + + change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) + if not isinstance(example_input, (tuple, list)): + assert isinstance(example_input, torch.Tensor) + example_input = [example_input] + model(*example_input) + change_autoquantizable_to_quantized(model, **kwargs) + return model + diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4194ceb9be..a830d52d78 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -15,15 +15,12 @@ and mixed GEMM kernels """ -import logging -from typing import Any, Dict, Tuple - import torch import torch.nn as nn import torch.nn.functional as F from .dynamic_quant import DynamicallyPerAxisQuantizedLinear -from .utils import TORCH_VERSION_AFTER_2_4 +from .utils import TORCH_VERSION_AFTER_2_3 from .subclass import ( Int4WeightOnlyQuantizedLinearWeight, @@ -32,11 +29,11 @@ QuantizedLinearWeightBase, ) from .weight_only import WeightOnlyInt8QuantLinear - -_AFTER_TORCH_2_4_ONLY = [ - "Int8DynActInt4WeightQuantizer", - "Int8DynActInt4WeightGPTQQuantizer", -] +from .unified import Quantizer, TwoStepQuantizer +from .GPTQ import ( + Int4WeightOnlyGPTQQuantizer, + Int4WeightOnlyQuantizer, +) __all__ = [ @@ -48,35 +45,20 @@ "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", -] + (_AFTER_TORCH_2_4_ONLY if TORCH_VERSION_AFTER_2_4 else []) - - -############################# Unified Quantization APIs ############################## -# API 1, single quantize call to create a quantized model with quantized state_dict -class Quantizer: - def quantize( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - - pass - - -# API 2, flow that needs calibration or training -class TwoStepQuantizer: - def prepare( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - - pass - - def convert( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - - pass + "Int4WeightOnlyGPTQQuantizer", + "Int4WeightOnlyQuantizer" +] +if TORCH_VERSION_AFTER_2_3: + from .GPTQ import ( + Int8DynActInt4WeightQuantizer, + Int8DynActInt4WeightGPTQQuantizer, -############################# Unified Quantization APIs ############################## + ) + __all__ += [ + "Int8DynActInt4WeightQuantizer", + "Int8DynActInt4WeightGPTQQuantizer", + ] def _replace_with_custom_fn_if_matches_filter( @@ -136,10 +118,11 @@ def apply_dynamic_quant(model, filter_fn=None): def _get_subclass_inserter(cls, **kwargs): - + method = kwargs.pop("method", "from_float") def insert_subclass(lin): lin.weight = torch.nn.Parameter( - cls.from_float(lin.weight, **kwargs), requires_grad=False + # cls.from_float(...) + getattr(cls, method)(lin.weight, **kwargs), requires_grad=False ) return lin @@ -191,7 +174,6 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): filter_fn, ) - def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. 
@@ -222,600 +204,3 @@ def replace_conv2d_1x1(conv): _replace_with_custom_fn_if_matches_filter( model, replace_conv2d_1x1, filter_fn=filter_fn ) - - -if TORCH_VERSION_AFTER_2_4: - from .quant_primitives import ( - get_group_qparams_symmetric, - group_quantize_tensor_symmetric, - per_token_dynamic_quant, - ) - - from .GPTQ import lm_eval_available - - if lm_eval_available: - from .GPTQ import ( - evaluate, - GenericGPTQRunner, - get_task_dict, - InputRecorder, - lm_eval, - MultiInput, - ) - else: - logging.info("lm_eval not available, skip defining GPTQQuantizer") - - - class GPTQQuantizer(Quantizer): - """ - This class implements a GPTQ Quantizer that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. - Unlike the base Quantizer class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement - __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. - - The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and - create_quantized_state_dict. Here is a description of each function. - - get_qparams_func: - A function that calculates the quantization qparams for an input tensor. - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - qparams: it can have any format but will need to be handled by the other defined functions below. - - quantize_func: - A function that applies quantization to an input tensor. It should be noted - that this function needs to be able to handle quantizing the entire weight tensor, a single group, - or a single column. - Args: - weight: A 2d weight tensor with non-integer dtype. - qparams: the output from get_qparams_func - Returns: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - - - dequantize_func: - A function that dequantizes an input quantized weight tensor. It should be noted - that this function needs to be able to handle dequantizing the entire weight tensor, a single group, - or a single column. - Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - weight: A 2d weight tensor with non-integer dtype. - - dyn_quant_func (optional): - A function that dynamically quantizes inputs - Args: - input: input Tensor in f32/bf16/f16 - Returns: - output: dynamically quantized and dequantized Tensor (with the same dtype as input) - - combine_qparams_list_func: - A function that combines several qparams into one qparam. - Args: - qparams_list: a list of qparams objects, each obtained by calling get_qparams_func - on a single group from a weight tensor - Returns: - qparams: an object of the same format as the qparams above. - - skip_layer_func: - A function that determines which linear layers should be skipped during GPTQ - Args: - weight: A 2d weight tensor with non-integer dtype. - Returns: - skip: boolean indicating whether layer should be skipped - - make_names_and_values_dict_func: - A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they - should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. 
- Args: - quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) - qparams: the output from get_qparams_func - Returns: - names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the - corresponding quantized weights and qparams. - """ - - def __init__(self): - - assert self.get_qparams_func is not None - - assert self.quantize_func is not None - - assert self.dequantize_func is not None - - assert self.combine_qparams_list_func is not None - - # `make_names_and_values_dict_func`. - assert self.make_names_and_values_dict_func is not None - - @staticmethod - def get_inputs( - model, - tokenizer, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - ) -> "MultiInput": - input_recorder = InputRecorder( - model, - tokenizer, - calibration_seq_length, - pad_calibration_inputs, - ) - - try: - - lm_eval.tasks.initialize_tasks() - except: - pass - - task_dict = get_task_dict(calibration_tasks) - print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) - - evaluate( - input_recorder, - task_dict, - limit=calibration_limit, - ) - inputs = input_recorder.get_recorded_inputs() - assert inputs is not None, ( - f"No inputs were collected, use a task other than {calibration_tasks}, " - + "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " - + f"{calibration_seq_length})" - ) - print(f"Obtained {len(inputs[0].values)} calibration samples") - return inputs - - @torch.no_grad() - def _create_quantized_state_dict( - self, - model, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - # `typing.Dict[, ]` to avoid runtime subscripting errors. 
- ) -> Dict: - inputs = GPTQQuantizer.get_inputs( - model, - tokenizer, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - ) - print("Tracing model for GPTQ") - GPTQ_runner = GenericGPTQRunner( - model, - inputs, - blocksize, - percdamp, - groupsize, - ).configure_quantization_mode( - self.get_qparams_func, # pyre-ignore[16] - self.quantize_func, # pyre-ignore[16] - self.dequantize_func, # pyre-ignore[16] - self.combine_qparams_list_func, # pyre-ignore[16] - self.make_names_and_values_dict_func, # pyre-ignore[16] - self.skip_layer_func, # pyre-ignore[16] - self.dyn_quant_func if hasattr(self, "dyn_quant_func") else None, # pyre-ignore[16] - ) - print("Applying GPTQ to weights") - GPTQ_runner.run() - return GPTQ_runner.get_quantized_state_dict() - - def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module": - raise NotImplementedError("_convert_for_runtime not implemented") - - @torch.no_grad() - def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module: - state_dict = self._create_quantized_state_dict( - model, - self.tokenizer, - self.blocksize, - self.percdamp, - self.groupsize, - self.calibration_tasks, - self.calibration_limit, - self.calibration_seq_length, - self.pad_calibration_inputs, - ) - model = self._convert_for_runtime(model) - model.load_state_dict(state_dict, strict=False) - return model - - - def linear_forward_8da4w( - x, - weight_int8, - scales, - zeros, - out_features, - group_size, - precision, - ): - x = per_token_dynamic_quant(x) - # TODO: verify and remove following reshape code - # origin_x_size = x.size() - # x = x.reshape(-1, origin_x_size[-1]) - - # TODO: better API - # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed) - n_bit = 4 - quant_min = -(2 ** (n_bit - 1)) - quant_max = 2 ** (n_bit - 1) - 1 - w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - weight_int8, - scales, - zeros, - quant_min, - quant_max, - torch.int8, - group_size, - precision, - ) - - # x = x.to(torch.float16) - # w_dq = w_dq.to(torch.float16) - c = torch.nn.functional.linear(x, w_dq) - - # new_shape = origin_x_size[:-1] + (out_features,) - # c = c.reshape(new_shape) - - return c - - - class Int8DynActInt4WeightLinear(torch.nn.Module): - __constants__ = ["in_features", "out_features"] - - in_features: int - out_features: int - weight: torch.Tensor - - """ - This module implements a dynamic quantized linear layer with int4 weight. - Weights are per channel groupwise quantized. Parameters of importance - group_size: the number of elements in each quantized group - precision: precision of input and output. e.g. torch.float32 means input - activation is float32 and output is float32. - scales_precision: precision of per group scale. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - bias=True, - device=None, - dtype=None, - group_size: int = 256, - precision: torch.dtype = torch.float32, - scales_precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - # always pad if needed since it becomes a noop at runtime if not needed - # self.origin_in_features = in_features - assert ( - in_features % group_size == 0 - ), f"require in_features:{in_features} % group_size:{group_size} == 0" - # in_features = _calc_padded_size_linear_int4( - # in_features, group_size - # ) - self.in_features = in_features - self.out_features = out_features - assert not bias, "require bias=False" - # TODO: align groupsize naming - self.group_size = group_size - # Precision of the activation which also indicates - # output precision of the dynamically quantized linear layer - # that his module represents. - self.precision = precision - - # currently storing unpacked int8 weights - self.register_buffer( - "weight", - torch.empty((out_features, in_features), dtype=torch.int8), - ) - self.register_buffer( - "scales", - torch.empty( - (out_features, in_features // group_size), - dtype=scales_precision, - ), - ) - self.register_buffer( - "zeros", - torch.empty( - (out_features, in_features // group_size), - dtype=scales_precision, - ), - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = input.to(self.precision) - # padding is removed for perf - # input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - return linear_forward_8da4w( - input, - self.weight, - self.scales, - self.zeros, - self.out_features, - self.group_size, - self.precision, - ) - - - from functools import reduce - from math import gcd - - - def find_multiple(n: int, *args: Tuple[int]) -> int: - # TODO: this change is reverted right now in gpt-fast - k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9] - if n % k == 0: - return n - return n + k - (n % k) - - - def _check_linear_int4_k(k, group_size=1): - return k % group_size == 0 - - - def _calc_padded_size_linear_int4(k, groupsize=1): - return find_multiple(k, groupsize) - - - def pack_scales_and_zeros(scales, zeros, precision=torch.float32): - assert scales.shape == zeros.shape - assert scales.dtype == precision - assert zeros.dtype == precision - return ( - torch.cat( - [ - scales.reshape(scales.size(0), scales.size(1), 1), - zeros.reshape(zeros.size(0), zeros.size(1), 1), - ], - 2, - ) - .transpose(0, 1) - .contiguous() - ) - - - def unpack_scales_and_zeros(scales_and_zeros): - assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 - assert scales_and_zeros.dtype == torch.float - return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) - - - def replace_linear_8da4w( - module, - group_size, - padding_allowed, - precision, - scales_precision, - ): - for name, child in module.named_children(): - if isinstance(child, nn.Linear): - if _check_linear_int4_k(child.in_features, group_size) or padding_allowed: - setattr( - module, - name, - Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - group_size=group_size, - precision=precision, - scales_precision=scales_precision, - ), - ) - else: - replace_linear_8da4w( - child, - group_size, - padding_allowed, - precision, - scales_precision, - ) - - - class Int8DynActInt4WeightQuantizer(Quantizer): - def __init__( - self, - group_size: int = 256, - padding_allowed: bool = False, - precision: torch.dtype = torch.float32, - 
scales_precision: torch.dtype = torch.float32, - ) -> None: - self.group_size: int = group_size - self.padding_allowed: bool = padding_allowed - self.precision: torch.dtype = precision - self.scales_precision: torch.dtype = scales_precision - # assert group_size in [32, 64, 128, 256] - - @torch.no_grad() - def _create_quantized_state_dict( - self, model: torch.nn.Module - ) -> Dict[str, torch.Tensor]: - cur_state_dict = model.state_dict() - for fqn, mod in model.named_modules(): - if isinstance(mod, torch.nn.Linear): - assert not mod.bias - out_features = mod.out_features - in_features = mod.in_features - # assert out_features % 8 == 0, "require out_features % 8 == 0" - print(f"linear: {fqn}, in={in_features}, out={out_features}") - - assert ( - in_features % self.group_size == 0 - ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0" - - weight = mod.weight.data - """ - if not _check_linear_int4_k( - in_features, self.group_size - ): - if self.padding_allowed: - print( - f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" - ) - padded_in_features = _calc_padded_size_linear_int4( - in_features, self.group_size - ) - weight = F.pad( - weight, pad=(0, padded_in_features - in_features) - ) - else: - raise RuntimeError( - f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " - + "and that group_size" - ) - """ - ( - weight_int8, - scales, - zeros, - ) = group_quantize_tensor_symmetric( - weight.to(self.precision), - 4, # n_bit - self.group_size, - self.scales_precision, - ) - cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu") - cur_state_dict[f"{fqn}.scales"] = scales.to("cpu") - cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu") - # TODO: support bias? - - return cur_state_dict - - def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: - replace_linear_8da4w( - model, - self.group_size, - self.padding_allowed, - self.precision, - self.scales_precision, - ) - return model - - def quantize( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - state_dict = self._create_quantized_state_dict(model) - model = self._convert_for_runtime(model) - # TODO: make it strict - model.load_state_dict(state_dict, strict=False) - return model - - - class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer): - - def __init__( - self, - tokenizer, - blocksize, - percdamp, - groupsize, - calibration_tasks, - calibration_limit, - calibration_seq_length, - pad_calibration_inputs, - inner_k_tiles=8, - padding_allowed=True, - precision=torch.float32, - ): - - self.tokenizer = tokenizer - - self.blocksize = blocksize - - self.percdamp = percdamp - - self.groupsize = groupsize - - self.calibration_tasks = calibration_tasks - - self.calibration_limit = calibration_limit - - self.calibration_seq_length = calibration_seq_length - - self.pad_calibration_inputs = pad_calibration_inputs - - self.inner_k_tiles = inner_k_tiles - - self.padding_allowed = padding_allowed - - self.precision = precision - - self.dyn_quant_func = per_token_dynamic_quant - n_bit = 4 - - self.get_qparams_func = lambda w: get_group_qparams_symmetric( - w, n_bit, groupsize, self.precision - ) - quant_min = -(2 ** (n_bit - 1)) - quant_max = 2 ** (n_bit - 1) - 1 - - self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group( - w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize - ) - - self.dequantize_func = lambda q, qparams: 
-
-
-class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
-
-    def __init__(
-        self,
-        tokenizer,
-        blocksize,
-        percdamp,
-        groupsize,
-        calibration_tasks,
-        calibration_limit,
-        calibration_seq_length,
-        pad_calibration_inputs,
-        inner_k_tiles=8,
-        padding_allowed=True,
-        precision=torch.float32,
-    ):
-
-        self.tokenizer = tokenizer
-
-        self.blocksize = blocksize
-
-        self.percdamp = percdamp
-
-        self.groupsize = groupsize
-
-        self.calibration_tasks = calibration_tasks
-
-        self.calibration_limit = calibration_limit
-
-        self.calibration_seq_length = calibration_seq_length
-
-        self.pad_calibration_inputs = pad_calibration_inputs
-
-        self.inner_k_tiles = inner_k_tiles
-
-        self.padding_allowed = padding_allowed
-
-        self.precision = precision
-
-        self.dyn_quant_func = per_token_dynamic_quant
-        n_bit = 4
-
-        self.get_qparams_func = lambda w: get_group_qparams_symmetric(
-            w, n_bit, groupsize, self.precision
-        )
-        quant_min = -(2 ** (n_bit - 1))
-        quant_max = 2 ** (n_bit - 1) - 1
-
-        self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group(
-            w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize
-        )
-
-        self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group(
-            q,
-            qparams[0],
-            qparams[1],
-            quant_min,
-            quant_max,
-            torch.int8,
-            groupsize,
-            self.precision,
-        )
-
-        self.combine_qparams_list_func = lambda qparams_list: [
-            torch.cat(x, dim=1) for x in zip(*qparams_list)
-        ]
-        # skip unless padding_allowed=True or its correctly sized
-
-        self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed
-        )
-
-        # we need to do the padding here, both for q and the qparams if necessary
-
-        def make_names_and_values_dict_func(q, qparams):
-            k = q.shape[1]
-            new_k = _calc_padded_size_linear_int4(k, groupsize)
-            # how much we need to pad the weight
-            delta_k = new_k - q.shape[1]
-            final_q = F.pad(q, pad=(0, delta_k))
-            scales = qparams[0].to(self.precision)
-            zeros = qparams[1].to(self.precision)
-            # scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
-            # how many new groups we need for padded weight
-            # delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-            # TODO: split scales and zero_points
-            # final_s_and_z = F.pad(
-            #     scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
-            # )
-            return {"weight": final_q, "scales": scales, "zeros": zeros}
-
-        self.make_names_and_values_dict_func = make_names_and_values_dict_func
-        super().__init__()
-
-    def _convert_for_runtime(self, model):
-        replace_linear_8da4w(
-            model,
-            self.groupsize,
-            self.padding_allowed,
-            self.precision,
-            self.precision,
-        )
-        return model
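The GPTQ variant above mostly wires per-group int4 quantize/dequantize lambdas into the shared GPTQ machinery; its calibration arguments follow gpt-fast conventions rather than anything defined in this diff. A construction sketch with placeholder values (the literals below are illustrative, and calling `quantize` is assumed to be inherited from the `GPTQQuantizer` base class):

```python
# Placeholder values throughout; this only makes the constructor signature concrete.
quantizer = Int8DynActInt4WeightGPTQQuantizer(
    tokenizer=tokenizer,            # assumed: a gpt-fast style tokenizer for calibration data
    blocksize=128,
    percdamp=0.01,
    groupsize=256,
    calibration_tasks=["wikitext"],
    calibration_limit=1,
    calibration_seq_length=2048,
    pad_calibration_inputs=False,
)
model = quantizer.quantize(model)   # assumed to be provided by the GPTQQuantizer base class
```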
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index c8ff618154..88eafd4b2a 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -11,10 +11,11 @@
 from torch.library import impl
 
 from torchao.kernel.intmm import int_scaled_matmul
-from .utils import TORCH_VERSION_AFTER_2_4
+from torchao.kernel.intmm import safe_int_mm
+from .utils import TORCH_VERSION_AFTER_2_3
 
-_AFTER_TORCH_2_4_ONLY = [
+_AFTER_TORCH_2_3_ONLY = [
     "per_token_dynamic_quant",
     "get_group_qparams_symmetric",
 ]
@@ -38,66 +39,10 @@
     "groupwise_affine_quantize_tensor",
     "groupwise_affine_dequantize_tensor",
     # TODO: need to clean up above functions
-] + (_AFTER_TORCH_2_4_ONLY if TORCH_VERSION_AFTER_2_4 else [])
-
-
-def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
-    r"""
-    This function wraps torch._int_mm and avoids several undesirable behaviors of the function for certain inputs while still
-    returning correct results and being torch.compiled in a performant way.
-
-    Assumes both tensors have dimension of 2.
-
-    Note: no error checking for torch.compiled path, if input.shape = [i, j] and j<=16 then the triton kernel
-    will error.
-
-    Args:
-        input (Tensor, int8): the first tensor to be multiplied
-        mat2 (Tensor, int8): the second tensor to be multiplied
-
-    Return:
-        out (Tensor, int32): the result of the matmul with device matching that of the inputs
-    """
-    # torch.compile path
-    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
-    # error checking for cublas path
-    assert (
-        mat2.device == input.device
-    ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
-    device_cpu = "cpu" in [mat2.device.type, input.device.type]
-    # with input.shape = [i,j] and mat2.shape = [j,k]
-    i_is_strictly_greater_than_16 = input.shape[0] > 16
-    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
-    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
-    bad_dimensions_for_cublas = not (
-        i_is_strictly_greater_than_16
-        and j_is_nonzero_multiple_of_8
-        and k_is_nonzero_multiple_of_8
-    )
-
-    if device_cpu or bad_dimensions_for_cublas:
-        # fallback path
-        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-            input.device.type
-        )
-
-    # cublas paths
-    if not mat2.is_contiguous():  # silently gives incorrect result without this
-        mat2 = mat2.contiguous()
-    if (not input.is_contiguous()) and (
-        input.shape[0] % 8 != 0
-    ):  # gives cryptic error without this
-        input = (
-            input.contiguous()
-        )  # (it seems the transpose makes cublas check the above j constraint on i)
-    return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
+] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
 
 # copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736
-
 def dynamically_quantize_per_tensor(
     x,
     quant_min,
@@ -438,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros):
 
 def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
     assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
-    assert scales_and_zeros.dtype == torch.float
     return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
 
 
@@ -525,6 +469,7 @@ def groupwise_affine_dequantize_tensor(
     )
 
 
+# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver
 def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
@@ -554,24 +499,7 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     )
 
 
-def pack_scales_and_zeros(scales, zeros, precision=torch.float16):
-    assert scales.shape == zeros.shape
-    assert scales.dtype == precision
-    assert zeros.dtype == precision
-    return (
-        torch.cat(
-            [
-                scales.reshape(scales.size(0), scales.size(1), 1),
-                zeros.reshape(zeros.size(0), zeros.size(1), 1),
-            ],
-            2,
-        )
-        .transpose(0, 1)
-        .contiguous()
-    )
-
-
-if TORCH_VERSION_AFTER_2_4:
+if TORCH_VERSION_AFTER_2_3:
     def group_quantize_tensor_symmetric(
         w,
         n_bit=4,
@@ -645,4 +573,4 @@ def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor:
         input = torch.ops.quantized_decomposed.dequantize_per_token(
             input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype
         )
-        return input
+        return input.to(orig_dtype)
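`safe_int_mm` itself is unchanged by this diff, only re-homed under `torchao.kernel.intmm` (see the new import at the top of this file), and the shape caveats from the removed docstring still apply. A minimal call that satisfies the cuBLAS-friendly shape checks:

```python
import torch
from torchao.kernel.intmm import safe_int_mm  # new location introduced in this diff

a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)  # i = 32 > 16
b = torch.randint(-128, 127, (64, 16), dtype=torch.int8)  # j = 64 and k = 16 are multiples of 8
out = safe_int_mm(a, b)  # int32 accumulation; falls back to a plain matmul on CPU
assert out.dtype == torch.int32 and out.shape == (32, 16)
```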
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 537099f67a..64689b8d95 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -301,7 +301,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
         # however the external representation of our tensor will maintain the correct
         # shape attribute which needs to be tracked directly.
         int_data = w_int_repr.contiguous().t()
-        if cls is not Int8DynamicallyQuantizedLinearWeight:
+        if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight):
             int_data = int_data.contiguous()
         return cls(
             int_data, w_scales, False, input_float.shape, dtype=input_float.dtype
diff --git a/torchao/quantization/unified.py b/torchao/quantization/unified.py
new file mode 100644
index 0000000000..16112ac0f0
--- /dev/null
+++ b/torchao/quantization/unified.py
@@ -0,0 +1,29 @@
+import torch
+from typing import Any
+
+############################# Unified Quantization APIs ##############################
+# API 1, single quantize call to create a quantized model with quantized state_dict
+class Quantizer:
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+
+        pass
+
+
+# API 2, flow that needs calibration or training
+class TwoStepQuantizer:
+    def prepare(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+
+        pass
+
+    def convert(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+
+        pass
+
+
+############################# Unified Quantization APIs ##############################
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index e20ed6cfc5..a178edf125 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -3,22 +3,26 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 from packaging import version
 
+from functools import reduce
+from math import gcd
+
 __all__ = [
     "find_multiple",
     "compute_error",
     "_apply_logging_hook",
    "get_model_size_in_bytes",
-    "TORCH_VERSION_AFTER_2_4",
+    "TORCH_VERSION_AFTER_2_3",
 ]
 
 
-def find_multiple(n: int, k: int) -> int:
+def find_multiple(n: int, *args: Tuple[int]) -> int:
+    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
     if n % k == 0:
         return n
     return n + k - (n % k)
@@ -95,8 +99,17 @@ def get_model_size_in_bytes(model):
             s += b.nelement() * b.element_size()
     return s
 
-
 if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
     TORCH_VERSION_AFTER_2_4 = True
 else:
     TORCH_VERSION_AFTER_2_4 = False
+
+if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
+    TORCH_VERSION_AFTER_2_3 = True
+else:
+    TORCH_VERSION_AFTER_2_3 = False
+
+if version.parse(torch.__version__) >= version.parse("2.2.0.dev"):
+    TORCH_VERSION_AFTER_2_2 = True
+else:
+    TORCH_VERSION_AFTER_2_2 = False
diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py
index 2ab5adf3d1..099df0f17f 100644
--- a/torchao/quantization/weight_only.py
+++ b/torchao/quantization/weight_only.py
@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
         scales = kwargs.pop("scales")
         super().__init__(*args, **kwargs)
 
-        self.w_int8 = w_int8
-
-        self.scales = scales
+        self.register_buffer("w_int8", w_int8)
+        self.register_buffer("scales", scales)
 
     def forward(self, x, *args, **kwargs):
         """