From ff231fb55e33c37126a0ef7f0e739cd750d1ef6c Mon Sep 17 00:00:00 2001
From: XiaobingZhang
Date: Mon, 23 May 2022 11:08:03 +0800
Subject: [PATCH] quantization: support dynamic linear and lstm (#787)

---
 .../ao/quantization/README.md                 | 43 +++++++++-
 .../ao/quantization/_module_swap_utils.py     | 78 +++++++++++++++++++
 .../ao/quantization/_quantization_state.py    | 14 +++-
 .../quantization/_quantization_state_utils.py | 14 +++-
 .../ao/quantization/_quantize_utils.py        |  5 +-
 .../csrc/jit/fusion_pass.cpp                  |  6 +-
 tests/cpu/test_ao_jit_ipex_quantization.py    | 41 +++++++++-
 7 files changed, 186 insertions(+), 15 deletions(-)
 create mode 100644 intel_extension_for_pytorch/ao/quantization/_module_swap_utils.py

diff --git a/intel_extension_for_pytorch/ao/quantization/README.md b/intel_extension_for_pytorch/ao/quantization/README.md
index 5496cf055..f3c34f5a4 100644
--- a/intel_extension_for_pytorch/ao/quantization/README.md
+++ b/intel_extension_for_pytorch/ao/quantization/README.md
@@ -46,7 +46,7 @@ for data in calibration_data_set:
 # prepared_model.load_qconf_summary(qconf_summary = "configure.json")
 ```
 
-### Convert to Quantized Model and Deploy
+### Convert to Static Quantized Model and Deploy
 
 ```python
 # make sure the example_inputs's size is same as the real input's size
@@ -63,9 +63,46 @@ y = traced_model(x)
 # quantized_model = torch.jit.load("quantized_model.pt")
 # quantized_model = torch.jit.freeze(quantized_model.eval())
 # ...
-
 ```
 
 ## Dynamic Quantization
 
-TODO(future PR):
+```python
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.quantization import prepare, convert
+```
+
+### Define QConfig
+
+```python
+from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig
+dynamic_qconfig = QConfig(
+    activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+    weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
+```
+
+Note: For the weight observer, only dtype **torch.qint8** is supported, and the qscheme can only be **torch.per_tensor_symmetric**.
+
+### Prepare Model
+
+```python
+prepared_model = prepare(user_model, dynamic_qconfig, example_inputs=example_inputs, inplace=False)
+```
+
+### Convert to Dynamic Quantized Model and Deploy
+
+```python
+# make sure the example_inputs' size is the same as the real input's size
+convert_model = convert(prepared_model)
+with torch.no_grad():
+    traced_model = torch.jit.trace(convert_model, example_inputs)
+    traced_model = torch.jit.freeze(traced_model)
+# for inference
+y = traced_model(x)
+
+# or save the model to deploy
+# traced_model.save("quantized_model.pt")
+# quantized_model = torch.jit.load("quantized_model.pt")
+# quantized_model = torch.jit.freeze(quantized_model.eval())
+# ...
+```
diff --git a/intel_extension_for_pytorch/ao/quantization/_module_swap_utils.py b/intel_extension_for_pytorch/ao/quantization/_module_swap_utils.py
new file mode 100644
index 000000000..b32b98281
--- /dev/null
+++ b/intel_extension_for_pytorch/ao/quantization/_module_swap_utils.py
@@ -0,0 +1,78 @@
+from typing import Dict, Callable, Any, Optional
+
+import torch
+import torch.nn as nn
+
+from torch.ao.quantization import swap_module
+import torch.nn.quantized.dynamic as nnqd
+
+
+# Default map for swapping float modules to dynamically quantized modules
+DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Linear: nnqd.Linear,
+    nn.LSTM: nnqd.LSTM,
+    # TODO: support more RNN modules
+    #nn.GRUCell: nnqd.GRUCell,
+    #nn.GRU: nnqd.GRU,
+    #nn.LSTMCell: nnqd.LSTMCell,
+    #nn.RNNCell: nnqd.RNNCell,
+}
+
+def _get_qconfig_dtypes(qconfig):
+    r"""
+    Returns the dtype tuple for the given qconfig:
+    (activation_dtype, weight_dtype, activation_compute_dtype)
+    """
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
+    return (activation.dtype, weight.dtype, compute_dtype)
+
+def _op_is_int8_dynamically_quantized(qconfig) -> bool:
+    r"""
+    Given a qconfig, returns True if this op is using int8 dynamic
+    quantization.
+    """
+    activation_dtype, weight_dtype, activation_compute_dtype = \
+        _get_qconfig_dtypes(qconfig)
+    return (
+        activation_dtype is torch.float and
+        # for now, the lines below assume fbgemm or qnnpack
+        weight_dtype is torch.qint8 and
+        activation_compute_dtype is torch.quint8
+    )
+
+
+def swap_child_modules(
+    module: torch.nn.Module,
+    dynamic_mappings: Dict[Callable, Any] = DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
+    parent_fqn: Optional[str] = None,
+) -> None:
+    """
+    For each direct child of `module`, swap it using `dynamic_mappings`
+    if the qconfig for that child uses int8 dynamic quantization
+    and the module type is in the mapping.
+    Recursively calls itself on each child.
+    """
+
+    if hasattr(module, '_auto_quant_state'):
+        qstate = module._auto_quant_state
+        for _, qopinfo in qstate.idx_to_seen_q_op_infos.items():
+            qconfig = qopinfo.qconfig
+            if not qconfig:
+                continue
+            fqn = qopinfo.fqn
+            if not fqn:
+                continue
+            op_int8_dynamically_quantized = _op_is_int8_dynamically_quantized(qconfig)
+
+            if op_int8_dynamically_quantized:
+                mod = module._modules[fqn]
+                if type(mod) not in dynamic_mappings:
+                    continue
+                mod.qconfig = qconfig
+                module._modules[fqn] = swap_module(mod, dynamic_mappings, {})
+
+    for _, child in module.named_children():
+        swap_child_modules(child, dynamic_mappings)
diff --git a/intel_extension_for_pytorch/ao/quantization/_quantization_state.py b/intel_extension_for_pytorch/ao/quantization/_quantization_state.py
index 7a50433d4..486322229 100644
--- a/intel_extension_for_pytorch/ao/quantization/_quantization_state.py
+++ b/intel_extension_for_pytorch/ao/quantization/_quantization_state.py
@@ -416,6 +416,10 @@ def op_weight_convert_before_hook(
         if op.bias:
             new_args.append(weights[tensor_arg_idx + 2])
             new_args.append(weights[tensor_arg_idx + 3])
+        else:
+            for s in range(step):
+                new_args.append(weights[tensor_arg_idx + s])
+
     return new_args
 
 def op_convert_after_hook(
@@ -713,7 +717,8 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
                 # always add observer if the op can be quantized.
                 tensor_id = tensor_info.id  # type: ignore[attr-defined]
                 weight_arg_idx = get_weight_arg_idx(seen_q_op_info.type)
-                if idx == weight_arg_idx:
+                # Avoid adding a weight observer for dynamic quantization.
+                if idx == weight_arg_idx and not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
                     # conv_transpose weight is iohw or iodhw, so we change the observer axis to 1.
                     if seen_q_op_info.type in [str(F.conv_transpose2d), str(F.conv_transpose3d)] and \
                         isinstance(qconfig.weight(), torch.ao.quantization.PerChannelMinMaxObserver):
@@ -736,17 +741,18 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
                 tensor_id = tensor_info.id  # type: ignore[attr-defined]
                 if seen_q_op_info.type == str(torch.nn.EmbeddingBag):
                     obs = qconfig.activation()
-                else:
+                    self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
+                elif not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
                     if seen_q_op_info.type in [str(torch.nn.ConvTranspose2d), str(torch.nn.ConvTranspose3d)] and \
                         isinstance(qconfig.weight(), torch.ao.quantization.PerChannelMinMaxObserver):
                         obs = qconfig.weight.with_args(ch_axis=1)()
                     else:
                         obs = qconfig.weight()
-                self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
+                    self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
         # LSTM, we don't know whether has bais or not, so we add observer for all them, but will not use them at convert step.
         # w_ih, w_hh share same observe, and b_ih, b_hh also share same observer
         if seen_q_op_info.type == str(torch.nn.LSTM):
-            if qconfig is not None:
+            if qconfig is not None and not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
                 for i in range(0, len(seen_q_op_info.weight_tensor_infos), 2):
                     tensor_id = seen_q_op_info.weight_tensor_infos[i].id
                     obs = qconfig.weight()
diff --git a/intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py b/intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py
index dd387414c..69fddc50c 100644
--- a/intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py
+++ b/intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.nn.quantized.dynamic as nnqd
 from intel_extension_for_pytorch.nn.functional import interaction
 import intel_extension_for_pytorch._C as core
 
@@ -63,14 +64,23 @@
     torch.nn.EmbeddingBag,
     torch.nn.Flatten,
     torch.nn.LSTM,
+    # dynamic quantization modules
+    nnqd.Linear,
+    nnqd.LSTM,
 ])
 
 may_inplace_module = set([
     torch.nn.ReLU,
 ])
-binary_related_ops = (
+
+a_related_to_b = (
     (str(torch.add), str(torch.Tensor.add)),
+    (str(torch.Tensor.add), str(torch.add)),
+    (str(torch.nn.Linear), str(nnqd.Linear)),
+    (str(nnqd.Linear), str(torch.nn.Linear)),
+    (str(torch.nn.LSTM), str(nnqd.LSTM)),
+    (str(nnqd.LSTM), str(torch.nn.LSTM)),
 )
 
 conv_linear_ops = [
@@ -123,7 +133,7 @@ def ops_are_related(
     if type_is_module:
         cur_op = type(cur_op)
     return str(cur_op) == expected_op_type or \
-        (str(cur_op), expected_op_type) in binary_related_ops
+        (str(cur_op), expected_op_type) in a_related_to_b
 
 def _raise_obs_not_found_error(func):
     raise RuntimeError(
diff --git a/intel_extension_for_pytorch/ao/quantization/_quantize_utils.py b/intel_extension_for_pytorch/ao/quantization/_quantize_utils.py
index 85a54c77c..671236b5b 100644
--- a/intel_extension_for_pytorch/ao/quantization/_quantize_utils.py
+++ b/intel_extension_for_pytorch/ao/quantization/_quantize_utils.py
@@ -10,7 +10,7 @@
     sync_pool_input_output_scale_zp, module_call_to_function_call, quantized_modules_has_weights, load_qconf_summary_to_model
 from ._quantization_state import AutoQuantizationState, AutoQuantizationStateModuleDict, init_model_quant_state
 from ._recipe import get_defaut_recipe
-
+from ._module_swap_utils import swap_child_modules
 
 # AutoQuantizationState lives in parent module's _modules.
 # Currently, `torch.nn.Sequential`'s forward iterates over all
@@ -540,7 +540,8 @@ def unwrap_proxy(a):
     for _, v in module._fqn_to_auto_quant_state_map.items():
         v.tensor_id_to_observer.clear()
         v.weight_tensor_id_to_observer.clear()
-    # Attach quan_info to parent each module
+    # Attach quant_info to each module's parent
     attach_op_convert_info_to_model(module)
+    swap_child_modules(module)
     module.__class__ = QuantizationDispatchModule
     return module
diff --git a/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp b/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
index e2c0e27df..a24704a1b 100644
--- a/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
+++ b/intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
@@ -455,7 +455,9 @@ bool checkQuantization(Block* block) {
 
     if (node->kind() == Symbol::aten("quantize_per_tensor") ||
         node->kind() == Symbol::aten("dequantize") ||
-        node->kind() == Symbol::aten("quantize_per_channel")) {
+        node->kind() == Symbol::aten("quantize_per_channel") ||
+        node->kind() == Symbol::aten("quantized_lstm") ||
+        node->kind() == Symbol::fromQualString("quantized::linear_dynamic")) {
       return true;
     }
   }
@@ -476,11 +478,11 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
 
   // remove BailOut and BailoutTemplate
  RemoveBailOutNodesAndSpecializeTypes(graph->block());
   RemoveBailoutTemplateNodes(graph->block());
-  // LLGA fusion pass for int8
   GRAPH_DUMP(
       "After RemoveProfileNodesAndSpecializeTypes. Before LLGA fusion pass",
       graph);
+
   if (isQuantized(graph) || torch_ipex::autocast::is_llga_fp32_bf16_enabled()) {
     RemoveRedundantAliases(graph);
     fuser::onednn::fuseGraph(graph);
diff --git a/tests/cpu/test_ao_jit_ipex_quantization.py b/tests/cpu/test_ao_jit_ipex_quantization.py
index 88220317b..9de5382e6 100644
--- a/tests/cpu/test_ao_jit_ipex_quantization.py
+++ b/tests/cpu/test_ao_jit_ipex_quantization.py
@@ -15,7 +15,8 @@
 from torch.testing._internal.common_utils import TEST_SCIPY, TemporaryFileName
 
 import intel_extension_for_pytorch as ipex
-from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, QConfig
+from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, \
+    QConfig, PlaceholderObserver
 
 default_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
 
@@ -34,6 +35,9 @@
         weight = default_weight_observer),
 ]
 
+dynamic_qconfig = QConfig(
+    activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+    weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
 
 class TestIpexOps(JitLlgaTestCase):
     def test_adaptive_avg_pool2d(self):
@@ -304,6 +308,39 @@ def forward(self, x):
         graph, _, _ = self.prepareModel(m, [x])
         FileCheck().check_not("aten::mul_").check("aten::mul").run(graph)
 
+class TestDynamicQuantization(JitLlgaTestCase):
+    def test_linear_dynamic(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.linear = torch.nn.Linear(3, 3)
+
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        m = M().eval()
+        x = torch.randn(1, 3)
+        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=dynamic_qconfig)
+        FileCheck().check_not("aten::linear").check("quantized::linear_dynamic").run(graph)
+
+    def test_lstm_dynamic(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.lstm = torch.nn.LSTM(10, 20, 2)
+
+            def forward(self, x, hx, cx):
+                x, h_xs = self.lstm(x, (hx, cx))
+                return x, h_xs
+
+        m = M().eval()
+        x = torch.randn(5, 3, 10)
+        h = torch.randn(2, 3, 20)
+        c = torch.randn(2, 3, 20)
+        graph = self.checkQuantizeTrace(m, [x, h, c], atol=2e-1, qconfig=dynamic_qconfig)
+        FileCheck().check_not("aten::lstm").check("aten::quantized_lstm").run(graph)
+
 if __name__ == '__main__':
-    run_tests()
\ No newline at end of file
+    run_tests()
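
For reference, a minimal end-to-end sketch of the dynamic quantization flow this patch enables, following the README and the new tests above; the toy module `M`, its layer sizes, and the input shape are illustrative placeholders rather than part of the change:

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig

# Dynamic qconfig as defined in the README and tests above.
dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

# Placeholder model; any module containing nn.Linear / nn.LSTM follows the same flow.
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

m = M().eval()
example_inputs = torch.randn(1, 3)

# Prepare, convert, then trace; nn.Linear is swapped for a dynamically quantized
# module, which the JIT fusion pass recognizes as quantized::linear_dynamic.
prepared_model = prepare(m, dynamic_qconfig, example_inputs=example_inputs, inplace=False)
converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, example_inputs)
    traced_model = torch.jit.freeze(traced_model)
    y = traced_model(example_inputs)
```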