Register pt2e static quantization #1761

Merged
18 commits merged on May 9, 2024
3 changes: 3 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer
10 changes: 6 additions & 4 deletions neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -28,10 +28,11 @@
from torch.fx.graph_module import GraphModule

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class W8A8StaticQuantizer:
class W8A8StaticQuantizer(Quantizer):

@staticmethod
def update_quantizer_based_on_quant_config(quantizer: X86InductorQuantizer, quant_config) -> X86InductorQuantizer:
@@ -69,21 +70,22 @@ def export_model(
logger.error(f"Failed to export the model: {e}")
return exported_model

def prepare(
self, model: torch.nn.Module, quant_config, example_inputs: Tuple[Any], *args: Any, **kwargs: Any
) -> GraphModule:
def prepare(self, model: torch.nn.Module, example_inputs, inplace=True, *args, **kwargs) -> GraphModule:
"""Prepare the model for calibration.

There are two steps in this process:
1) export the eager model into a graph module with Aten IR.
2) create the `quantizer` according to the `quant_config`, and insert the observers into the model accordingly.
"""
quant_config = self.quant_config
assert isinstance(example_inputs, tuple), f"Expected `example_inputs` to be a tuple, got {type(example_inputs)}"
# Set the model to eval mode
model = model.eval()

# 1) Capture the FX Graph to be quantized
dynamic_shapes = kwargs.get("dynamic_shapes", None)
if quant_config is not None:
dynamic_shapes = quant_config.dynamic_shapes
exported_model = self.export_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
logger.info("Exported the model to Aten IR successfully.")
if exported_model is None:
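For context, a minimal calibration flow against the updated prepare() signature might look like the sketch below (torch >= 2.3 is assumed, the toy model and inputs are placeholders, and the final convert() call mirrors the companion tests rather than this hunk):

import torch

from neural_compressor.torch.algorithms.pt2e_quant import W8A8StaticQuantizer
from neural_compressor.torch.quantization import get_default_pt2e_static_config

# Toy model and example inputs, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 10))
example_inputs = (torch.randn(4, 10),)

# The quant config now travels with the quantizer instance instead of the prepare() call.
quantizer = W8A8StaticQuantizer(quant_config=get_default_pt2e_static_config())
prepared = quantizer.prepare(model, example_inputs=example_inputs)
for _ in range(2):  # calibration passes
    prepared(*example_inputs)
q_model = quantizer.convert(prepared)  # conversion step as exercised in the tests; assumed part of the Quantizer API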
2 changes: 2 additions & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -34,6 +34,8 @@
FP8Config,
get_default_fp8_config,
get_default_fp8_config_set,
PT2EStaticQuantConfig,
get_default_pt2e_static_config,
)

from neural_compressor.torch.quantization.autotune import (
20 changes: 20 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -31,6 +31,7 @@
TEQConfig,
)
from neural_compressor.torch.utils import Mode, logger, register_algo
from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT


###################### RTN Algo Entry ##################################
@@ -160,6 +161,25 @@ def static_quant_entry(
return model


###################### PT2E Static Quant Algo Entry ##################################
@register_algo(name=PT2E_STATIC_QUANT)
@torch.no_grad()
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8StaticQuantizer

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
inplace = kwargs.get("inplace", True)
for _, quant_config in configs_mapping.items():
if quant_config.name == PT2E_STATIC_QUANT:
w8a8_quantizer = W8A8StaticQuantizer(quant_config=quant_config)
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
return model


###################### Smooth Quant Algo Entry ##################################
@register_algo(name=SMOOTH_QUANT)
@torch.no_grad()
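The entry pulls run_fn, example_inputs, and inplace out of kwargs and only acts on configs whose name matches PT2E_STATIC_QUANT. A minimal sketch of a direct invocation (the toy model, mapping key, and direct call are illustrative; real callers reach this entry through prepare/convert/quantize via register_algo):

import torch

from neural_compressor.torch.quantization import get_default_pt2e_static_config
from neural_compressor.torch.utils import Mode

model = torch.nn.Linear(8, 8)          # toy model, illustrative only
example_inputs = (torch.randn(2, 8),)

def calib_fn(m):
    # A couple of forward passes over the example inputs populates the observers.
    for _ in range(2):
        m(*example_inputs)

q_model = pt2e_static_quant_entry(
    model,
    configs_mapping={"model": get_default_pt2e_static_config()},  # illustrative key; only the config's name is checked
    mode=Mode.QUANTIZE,
    run_fn=calib_fn,
    example_inputs=example_inputs,
)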
83 changes: 82 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -17,7 +17,7 @@
# pylint:disable=import-error

from collections import OrderedDict
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch

@@ -48,7 +48,9 @@
PRIORITY_HQQ,
PRIORITY_RTN,
PRIORITY_TEQ,
PT2E_STATIC_QUANT,
)
from neural_compressor.torch.utils.utility import _ConfigMappingWrapper

__all__ = [
"RTNConfig",
@@ -57,6 +59,8 @@
"get_default_gptq_config",
"HQQConfig",
"get_default_hqq_config",
"PT2EStaticQuantConfig",
"get_default_pt2e_static_config",
]


@@ -775,6 +779,83 @@ def get_default_AutoRound_config() -> AutoRoundConfig:
return AutoRoundConfig()


######################## PT2E Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_STATIC_QUANT)
class PT2EStaticQuantConfig(BaseConfig):
"""Config class for PT2E static quantization."""

name = PT2E_STATIC_QUANT
params_list = [
"w_dtype",
"w_sym",
"w_granularity",
"w_algo",
"act_dtype",
"act_sym",
"act_granularity",
"act_algo",
]
supported_configs: List[OperatorConfig] = []

def __init__(
self,
w_dtype: str = "int8",
w_sym: bool = True,
w_granularity: str = "per_channel",
w_algo: str = "minmax",
act_dtype: str = "uint8",
act_sym: bool = False,
act_granularity: str = "per_tensor",
act_algo: str = "kl",
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init PT2E Static Quant Configs."""
super().__init__(white_list=white_list)
self.w_dtype = w_dtype
self.w_sym = w_sym
self.w_granularity = w_granularity
self.w_algo = w_algo
self.act_dtype = act_dtype
self.act_sym = act_sym
self.act_granularity = act_granularity
self.act_algo = act_algo
# used by export to specify the dynamic shapes of the example inputs
self.dynamic_shapes = dynamic_shapes
self._post_init()

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
linear_static_config = cls()
operators = [torch.nn.Linear]
supported_configs.append(OperatorConfig(config=linear_static_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]:
pass

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "PT2EStaticQuantConfig", List["PT2EStaticQuantConfig"]]:
return cls(act_sym=[True, False], act_algo=["kl", "minmax"])

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDict[Union[str, str], OrderedDict[str, BaseConfig]]:
config_mapping = OrderedDict({self.name: self})
return config_mapping


def get_default_pt2e_static_config() -> PT2EStaticQuantConfig:
"""Generate the default pt2e static quant config.

Returns:
the default pt2e static quant config.
"""
return PT2EStaticQuantConfig()
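As a quick usage sketch (mirroring the tests added in this PR; the toy model and inputs are placeholders), the new config plugs into the existing prepare/convert flow:

import torch

from neural_compressor.torch.quantization import PT2EStaticQuantConfig, convert, prepare

model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 10))
example_inputs = (torch.randn(10, 10),)

quant_config = PT2EStaticQuantConfig(act_algo="minmax")  # or get_default_pt2e_static_config()
prepared_model = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
for _ in range(2):  # calibration
    prepared_model(*example_inputs)
q_model = convert(prepared_model)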


######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
class StaticQuantConfig(BaseConfig):
3 changes: 3 additions & 0 deletions neural_compressor/torch/utils/constants.py
@@ -49,3 +49,6 @@
PRIORITY_AWQ = 70
PRIORITY_TEQ = 60
PRIORITY_AUTOROUND = 50


PT2E_STATIC_QUANT = "pt2e_static_quant"
15 changes: 15 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -13,6 +13,7 @@
# limitations under the License.


from collections import OrderedDict
from enum import Enum
from typing import Callable, Dict, List, Tuple, Union

@@ -131,3 +132,17 @@ class Mode(Enum):
PREPARE = "prepare"
CONVERT = "convert"
QUANTIZE = "quantize"


class _ConfigMappingWrapper(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._orig_config = None

@property
def orig_config(self):
return self._orig_config

@orig_config.setter
def orig_config(self, value):
self._orig_config = value
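A minimal sketch of the intended use (the stored values are placeholders): the wrapper behaves like a normal OrderedDict while also carrying the original, un-expanded config.

from neural_compressor.torch.utils.utility import _ConfigMappingWrapper

wrapper = _ConfigMappingWrapper({"fc1": "per-op config placeholder"})  # regular OrderedDict behaviour
wrapper.orig_config = "top-level config placeholder"                   # stash the originating config
assert wrapper.orig_config == "top-level config placeholder"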
5 changes: 2 additions & 3 deletions test/3x/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py
@@ -50,10 +50,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
def test_quantizer_on_simple_model(self):
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
quant_config = None
w8a8_static_quantizer = W8A8StaticQuantizer()
# prepare
prepare_model = w8a8_static_quantizer.prepare(model, quant_config, example_inputs=example_inputs)
prepare_model = w8a8_static_quantizer.prepare(model, example_inputs=example_inputs)
# calibrate
for i in range(2):
prepare_model(*example_inputs)
@@ -80,7 +79,7 @@ def test_quantizer_on_llm(self):
quant_config = None
w8a8_static_quantizer = W8A8StaticQuantizer()
# prepare
prepare_model = w8a8_static_quantizer.prepare(model, quant_config, example_inputs=example_inputs)
prepare_model = w8a8_static_quantizer.prepare(model, example_inputs=example_inputs)
# calibrate
for i in range(2):
prepare_model(*example_inputs)
127 changes: 127 additions & 0 deletions test/3x/torch/quantization/test_pt2e_quant.py
@@ -0,0 +1,127 @@
import os
import unittest
from unittest.mock import patch

import pytest
import torch

from neural_compressor.common.utils import logger
from neural_compressor.torch.quantization import (
PT2EStaticQuantConfig,
convert,
get_default_pt2e_static_config,
prepare,
quantize,
)
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version


class TestPT2EQuantization:

@staticmethod
def get_toy_model():
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b

inp1 = torch.randn(10)
inp2 = torch.randn(10)
example_inputs = (inp1, inp2)
bar = Bar()
return bar, example_inputs

@staticmethod
def build_simple_torch_model_and_example_inputs():
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 20)
self.fc2 = torch.nn.Linear(20, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = torch.nn.functional.relu(x)
x = self.fc2(x)
return x

model = SimpleModel()
example_inputs = (torch.randn(10, 10),)
return model, example_inputs

@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
def test_quantize_simple_model(self):
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
quant_config = None

def calib_fn(model):
for i in range(2):
model(*example_inputs)

quant_config = get_default_pt2e_static_config()
q_model = quantize(model=model, quant_config=quant_config, example_inputs=example_inputs, run_fn=calib_fn)
from torch._inductor import config

config.freezing = True
opt_model = torch.compile(q_model)
out = opt_model(*example_inputs)
logger.warning("out shape is %s", out.shape)
assert out is not None

@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
def test_prepare_and_convert_on_simple_model(self):
model, example_inputs = self.build_simple_torch_model_and_example_inputs()
quant_config = None

def calib_fn(model):
for i in range(2):
model(*example_inputs)

quant_config = get_default_pt2e_static_config()

prepared_model = prepare(model, quant_config=quant_config, example_inputs=example_inputs)
calib_fn(prepared_model)
q_model = convert(prepared_model)
assert q_model is not None, "Quantization failed!"

from torch._inductor import config

config.freezing = True
opt_model = torch.compile(q_model)
out = opt_model(*example_inputs)
logger.warning("out shape is %s", out.shape)
assert out is not None

@pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
def test_prepare_and_convert_on_llm(self):
from transformers import AutoModelForCausalLM, AutoTokenizer

# set TOKENIZERS_PARALLELISM to false

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
example_inputs = (input_ids,)
quant_config = get_default_pt2e_static_config()
# prepare
prepare_model = prepare(model, quant_config, example_inputs=example_inputs)
# calibrate
for i in range(2):
prepare_model(*example_inputs)
# convert
converted_model = convert(prepare_model)
# inference
from torch._inductor import config

config.freezing = True
opt_model = torch.compile(converted_model)
out = opt_model(*example_inputs)
assert out.logits is not None