Skip to content

Commit

Permalink
Refactor the API for quant method argument for quantize function
Browse files Browse the repository at this point in the history
Summary:
Addressing feedback from #384 and #375

Test Plan:
regression tests

python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Jun 19, 2024
1 parent e5ee771 commit e27b898
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 109 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)

```python
from torchao.quantization.quant_api import quantize
m = quantize(m, "int4wo")
from torchao.quantization.quant_api import quantize, int4_weight_only
m = quantize(m, int4_weight_only())
```

Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
Expand Down
6 changes: 3 additions & 3 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
TestCase,
run_tests,
)
from torchao.quantization.quant_api import int4wo
from torchao.quantization.quant_api import int4_weight_only
import torch
import unittest

Expand All @@ -12,8 +12,8 @@ class TestAffineQuantized(TestCase):
def test_tensor_core_layout_transpose(self):
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
shape = t.shape
apply_int4wo_quant = int4wo(groupsize=32)
aqt = apply_int4wo_quant(t)
apply_int4_weight_only_quant = int4_weight_only(groupsize=32)
aqt = apply_int4_weight_only_quant(t)
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

Expand Down
18 changes: 9 additions & 9 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
DynamicallyPerAxisQuantizedLinear,
)
from torchao.quantization.quant_api import (
int4wo,
int8wo,
int8da_int8w,
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
quantize,
_replace_with_custom_fn_if_matches_filter,
)
Expand Down Expand Up @@ -98,21 +98,21 @@

def _int8wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8wo())
quantize(mod, int8_weight_only())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8da_int8w())
quantize(mod, int8_dynamic_activation_int8_weight())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4wo())
quantize(mod, int4_weight_only())
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod)
Expand Down Expand Up @@ -832,7 +832,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):

def api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4wo(**kwargs))
quantize(mod, int4_weight_only(**kwargs))
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod, **kwargs)
Expand All @@ -853,7 +853,7 @@ def test_dynamic_quant(self):
m = nn.Sequential(nn.Linear(K, N))

y_ref = m(x)
quantize(m, int8da_int8w())
quantize(m, int8_dynamic_activation_int8_weight())
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
Expand Down Expand Up @@ -1436,7 +1436,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
api(model)
size2 = torchao.utils.get_model_size_in_bytes(model)
self.assertTrue(size2 < size)




Expand Down
51 changes: 18 additions & 33 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
_replace_with_custom_fn_if_matches_filter,
Quantizer,
TwoStepQuantizer,
int8da_int4w,
int4wo,
int8wo,
int8da_int8w,
int8_dynamic_activation_int4_weight,
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
)
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
Expand Down Expand Up @@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:

class TorchCompileDynamicQuantizer(Quantizer):
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
quantize(model, int8da_int8w())
quantize(model, int8_dynamic_activation_int8_weight())
return model

class ToyLinearModel(torch.nn.Module):
Expand Down Expand Up @@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantize(m, int8da_int8w())
m = quantize(m, int8_dynamic_activation_int8_weight())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
)
m = ToyLinearModel().eval().cpu()
def api(model):
model = quantize(model, int8wo())
model = quantize(model, int8_weight_only())
unwrap_tensor_subclass(model)

api(m)
Expand Down Expand Up @@ -335,7 +335,7 @@ def test_8da4w_quantizer_eval(self):
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
def test_gptq_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models._eval import InputRecorder, TransformerEvalWrapper
torchao._models.llama.model.use_index_put_for_kv_cache = True
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_gptq_quantizer_int4wo(self):
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_quantizer_int4wo(self):
def test_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao._models._eval import TransformerEvalWrapper
precision = torch.bfloat16
Expand Down Expand Up @@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
groupsize = 32
group_size = 32
m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m = quantize(m, int8da_int4w(groupsize=groupsize))
m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand All @@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
m_copy = quantizer.quantize(m_copy)
assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
Expand All @@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self):
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

groupsize = 32
m = quantize(m, int4wo(groupsize=groupsize))
group_size = 32
m = quantize(m, int4_weight_only(group_size=group_size))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
Expand All @@ -552,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

m = quantize(m, int8wo())
m = quantize(m, int8_weight_only())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
Expand All @@ -575,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
m = quantize(m, int8da_int8w())
m = quantize(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand All @@ -602,29 +602,14 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

def test_register_apply_tensor_subclass(self):
from torchao import register_apply_tensor_subclass
def apply_my_dtype(weight):
return weight * 2

m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
with self.assertRaisesRegex(ValueError, "not supported"):
quantize(m, "my_dtype")

register_apply_tensor_subclass("my_dtype", apply_my_dtype)
# make sure it runs
quantize(m, "my_dtype")
m(*example_inputs)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_save_load(self):
m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = quantize(m, "int8_weight_only")
m = quantize(m, int8_weight_only())
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
Expand Down
2 changes: 0 additions & 2 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@
from torchao.quantization import (
autoquant,
quantize,
register_apply_tensor_subclass,
)
from . import dtypes

__all__ = [
"dtypes",
"autoquant",
"quantize",
"register_apply_tensor_subclass",
]
2 changes: 1 addition & 1 deletion torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .aqt import AffineQuantizedTensor, to_affine_quantized
from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized

__all__ = [
"NF4Tensor",
Expand Down
File renamed without changes.
22 changes: 10 additions & 12 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ from torch._inductor.runtime.runtime_utils import do_bench_gpu
import copy
from torchao.quantization.quant_api import (
quantize,
int4wo,
int4_weight_only,
)

class ToyLinearModel(torch.nn.Module):
Expand All @@ -102,8 +102,8 @@ example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
groupsize = 32
m = quantize(m, int4wo(groupsize=groupsize))
group_size = 32
m = quantize(m, int4_weight_only(group_size=group_size))

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
Expand Down Expand Up @@ -150,7 +150,7 @@ for n, m in model.named_modules():
The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support
`torch.export.export` and `torch.aot_compile` with the following workaround:
```
from torchao.quantization.utils import unwrap_tensor_subclass
from torchao.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)
Expand All @@ -167,11 +167,10 @@ torch._export.aot_compile(m_unwrapped, example_inputs)
```python
# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True
from torchao.quantization import quant_api

# for torch 2.4+
from torchao.quantization.quant_api import quantize
quantize(model, "int8_dynamic")
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
quantize(model, int8_dynamic_activation_int8_weight())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
Expand All @@ -182,9 +181,8 @@ change_linear_weights_to_int8_dqtensors(model)

```python
# for torch 2.4+
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_api import int8wo
quantize(model, "int8_weight_only")
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
Expand All @@ -198,8 +196,8 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is

```python
# for torch 2.4+
from torchao.quantization.quant_api import quantize
quantize(model, "int4_weight_only")
from torchao.quantization import quantize, int4_weight_only
quantize(model, int4_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
Expand Down
5 changes: 4 additions & 1 deletion torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"register_apply_tensor_subclass",
"int8_dynamic_act_int4_weight",
"int8_dynamic_act_int8_weight",
"int4_weight_only",
"int8_weight_only",
]
Loading

0 comments on commit e27b898

Please sign in to comment.