quantization: support dynamic linear and lstm (#787)
XiaobingSuper authored May 23, 2022
1 parent 940f189 commit ff231fb
Showing 7 changed files with 186 additions and 15 deletions.
43 changes: 40 additions & 3 deletions intel_extension_for_pytorch/ao/quantization/README.md
@@ -46,7 +46,7 @@ for data in calibration_data_set:
# prepared_model.load_qconf_summary(qconf_summary = "configure.json")
```

### Convert to Quantized Model and Deploy
### Convert to Static Quantized Model and Deploy

```python
# make sure the example_inputs' size matches the real input's size
@@ -63,9 +63,46 @@ y = traced_model(x)
# quantized_model = torch.jit.load("quantized_model.pt")
# quantized_model = torch.jit.freeze(quantized_model.eval())
# ...

```

## Dynamic Quantization

Dynamic quantization currently covers `torch.nn.Linear` and `torch.nn.LSTM`; support for more RNN modules is planned for a future PR.

```python
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
```

### Define QConfig

```python
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig
dynamic_qconfig = QConfig(
activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
```

Note: the weight observer only supports dtype **torch.qint8**, and the qscheme can only be **torch.per_tensor_symmetric**.
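To make the two roles concrete, here is a small illustration (observer arguments mirror the QConfig above; it is illustrative only): the activation `PlaceholderObserver` records no statistics and merely marks the op for dynamic quantization, while the weight `MinMaxObserver` yields the scale and zero point used to quantize the weight at convert time.

```python
import torch
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver

act_obs = PlaceholderObserver(dtype=torch.float, compute_dtype=torch.quint8)    # records nothing
wt_obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)  # tracks min/max

wt_obs(torch.randn(3, 3))          # observe a (fake) weight tensor
print(wt_obs.calculate_qparams())  # -> (scale, zero_point) used to quantize the weight at convert time
```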

### Prepare Model

```python
prepared_model = prepare(user_model, dynamic_qconfig, example_inputs=example_inputs, inplace=False)
```
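Here `user_model` and `example_inputs` are the FP32 module to quantize and a sample input. A minimal, hypothetical stand-in (the toy model below is illustrative only) could be:

```python
import torch
import torch.nn as nn

class UserModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)  # dynamic quantization currently targets Linear and LSTM

    def forward(self, x):
        return self.linear(x)

user_model = UserModel().eval()
example_inputs = torch.randn(1, 3)
```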

### Convert to Dynamic Quantized Model and Deploy

```python
# make sure the example_inputs' size matches the real input's size
convert_model = convert(prepared_model)
with torch.no_grad():
traced_model = torch.jit.trace(convert_model, example_inputs)
traced_model = torch.jit.freeze(traced_model)
# for inference
y = traced_model(x)

# or save the model to deploy
# traced_model.save("quantized_model.pt")
# quantized_model = torch.jit.load("quantized_model.pt")
# quantized_model = torch.jit.freeze(quantized_model.eval())
# ...
```
78 changes: 78 additions & 0 deletions intel_extension_for_pytorch/ao/quantization/_module_swap_utils.py
@@ -0,0 +1,78 @@
from typing import Dict, Callable, Any, Optional

import torch
import torch.nn as nn

from torch.ao.quantization import swap_module
import torch.nn.quantized.dynamic as nnqd


# Default map for swapping dynamic modules
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
nn.Linear: nnqd.Linear,
nn.LSTM: nnqd.LSTM,
# TODO: support more RNN module
#nn.GRUCell: nnqd.GRUCell,
#nn.GRU: nnqd.GRU,
#nn.LSTMCell: nnqd.LSTMCell,
#nn.RNNCell: nnqd.RNNCell,
}

def _get_qconfig_dtypes(qconfig):
r"""
Returns the dtype tuple for the given qconfig:
(activation_dtype, weight_dtype, activation_compute_dtype)
"""
assert qconfig is not None
activation = qconfig.activation()
weight = qconfig.weight()
compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
return (activation.dtype, weight.dtype, compute_dtype)

def _op_is_int8_dynamically_quantized(qconfig) -> bool:
r"""
Given a qconfig, returns True if this op is using int8 dynamic
quantization
"""
activation_dtype, weight_dtype, activation_compute_dtype = \
_get_qconfig_dtypes(qconfig)
return (
activation_dtype is torch.float and
# for now, the lines below assume fbgemm or qnnpack
weight_dtype is torch.qint8 and
activation_compute_dtype is torch.quint8
)


def swap_child_modules(
module: torch.nn.Module,
dynamic_mappings: Dict[Callable, Any] = DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
parent_fqn: Optional[str] = None,
) -> None:
"""
For each direct child of `module`, swaps it using `dynamic_mappings`
if the qconfig for that child is using int8 dynamic quantization,
and the module type is in the mapping.
Recursively calls itself on each child.
"""

if hasattr(module, '_auto_quant_state'):
qstate = module._auto_quant_state
for _, qopinfo in qstate.idx_to_seen_q_op_infos.items():
qconfig = qopinfo.qconfig
if not qconfig:
continue
fqn = qopinfo.fqn
if not fqn:
continue
op_int8_dynamically_quantized = _op_is_int8_dynamically_quantized(qconfig)

if op_int8_dynamically_quantized:
mod = module._modules[fqn]
if not type(mod) in dynamic_mappings:
continue
mod.qconfig = qconfig
module._modules[fqn] = swap_module(mod, dynamic_mappings, {})

for _, child in module.named_children():
swap_child_modules(child)
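As a standalone sketch of the swap this helper performs for each eligible child (the toy `nn.Linear` and the qconfig below are illustrative, mirroring the README), `torch.ao.quantization.swap_module` replaces an FP32 module that carries a dynamic qconfig with its `nnqd` counterpart:

```python
import torch
import torch.nn as nn
import torch.nn.quantized.dynamic as nnqd
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig, swap_module

dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

fp32_linear = nn.Linear(3, 3)
fp32_linear.qconfig = dynamic_qconfig  # swap_child_modules attaches the qconfig before swapping
dyn_linear = swap_module(fp32_linear, {nn.Linear: nnqd.Linear}, {})

print(type(dyn_linear))               # torch.nn.quantized.dynamic Linear
print(dyn_linear(torch.randn(1, 3)))  # float in/out; activations quantized dynamically at runtime
```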
14 changes: 10 additions & 4 deletions intel_extension_for_pytorch/ao/quantization/_quantization_state.py
@@ -416,6 +416,10 @@ def op_weight_convert_before_hook(
if op.bias:
new_args.append(weights[tensor_arg_idx + 2])
new_args.append(weights[tensor_arg_idx + 3])
else:
for s in range(step):
new_args.append(weights[tensor_arg_idx + s])

return new_args

def op_convert_after_hook(
@@ -713,7 +717,8 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
# always add observer if the op can be quantized.
tensor_id = tensor_info.id # type: ignore[attr-defined]
weight_arg_idx = get_weight_arg_idx(seen_q_op_info.type)
if idx == weight_arg_idx:
# avoid adding a weight observer for dynamic quantization.
if idx == weight_arg_idx and not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
# conv_transpose weight is iohw or iodhw, so we change the observer axis to 1.
if seen_q_op_info.type in [str(F.conv_transpose2d), str(F.conv_transpose3d)] and \
isinstance(qconfig.weight(), torch.ao.quantization.PerChannelMinMaxObserver):
@@ -736,17 +741,18 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
tensor_id = tensor_info.id # type: ignore[attr-defined]
if seen_q_op_info.type == str(torch.nn.EmbeddingBag):
obs = qconfig.activation()
else:
self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
elif not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
if seen_q_op_info.type in [str(torch.nn.ConvTranspose2d), str(torch.nn.ConvTranspose3d)] and \
isinstance(qconfig.weight(), torch.ao.quantization.PerChannelMinMaxObserver):
obs = qconfig.weight.with_args(ch_axis=1)()
else:
obs = qconfig.weight()
self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
# For LSTM, we don't know whether it has bias or not, so we add observers for all weights and biases, but they will not be used at the convert step.
# w_ih and w_hh share the same observer, and b_ih and b_hh also share the same observer.
if seen_q_op_info.type == str(torch.nn.LSTM):
if qconfig is not None:
if qconfig is not None and not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
for i in range(0, len(seen_q_op_info.weight_tensor_infos), 2):
tensor_id = seen_q_op_info.weight_tensor_infos[i].id
obs = qconfig.weight()
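The guards added above all hinge on the same test: a dynamic qconfig is recognized by its `PlaceholderObserver` activation, and in that case no weight observers are inserted. A small, self-contained illustration (the qconfig definitions mirror the README and are otherwise illustrative):

```python
import torch
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig

dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
static_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

dynamic_detected = isinstance(dynamic_qconfig.activation(), PlaceholderObserver)
static_detected = isinstance(static_qconfig.activation(), PlaceholderObserver)
print(dynamic_detected, static_detected)  # True False -> weight observers are only created for the static case
```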
Changes to another file in this commit (file name not shown):
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized.dynamic as nnqd
from intel_extension_for_pytorch.nn.functional import interaction
import intel_extension_for_pytorch._C as core

@@ -63,14 +64,23 @@
torch.nn.EmbeddingBag,
torch.nn.Flatten,
torch.nn.LSTM,
# dynamic quantization module
nnqd.Linear,
nnqd.LSTM,
])

may_inplace_module = set([
torch.nn.ReLU,
])

binary_related_ops = (

a_related_to_b = (
(str(torch.add), str(torch.Tensor.add)),
(str(torch.Tensor.add), str(torch.add)),
(str(torch.nn.Linear), str(nnqd.Linear)),
(str(nnqd.Linear), str(torch.nn.Linear)),
(str(torch.nn.LSTM), str(nnqd.LSTM)),
(str(nnqd.LSTM), str(torch.nn.LSTM)),
)

conv_linear_ops = [
@@ -123,7 +133,7 @@ def ops_are_related(
if type_is_module:
cur_op = type(cur_op)
return str(cur_op) == expected_op_type or \
(str(cur_op), expected_op_type) in binary_related_ops
(str(cur_op), expected_op_type) in a_related_to_b

def _raise_obs_not_found_error(func):
raise RuntimeError(
Changes to another file in this commit (file name not shown):
@@ -10,7 +10,7 @@
sync_pool_input_output_scale_zp, module_call_to_function_call, quantized_modules_has_weights, load_qconf_summary_to_model
from ._quantization_state import AutoQuantizationState, AutoQuantizationStateModuleDict, init_model_quant_state
from ._recipe import get_defaut_recipe

from ._module_swap_utils import swap_child_modules

# AutoQuantizationState lives in parent module's _modules.
# Currently, `torch.nn.Sequential`'s forward iterates over all
@@ -540,7 +540,8 @@ def unwrap_proxy(a):
for _, v in module._fqn_to_auto_quant_state_map.items():
v.tensor_id_to_observer.clear()
v.weight_tensor_id_to_observer.clear()
# Attach quan_info to parent each module
# Attach quant_info to each parent module
attach_op_convert_info_to_model(module)
swap_child_modules(module)
module.__class__ = QuantizationDispatchModule
return module
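The intended user-visible effect at convert time, sketched under the README flow above (the toy model, shapes, and the `prepare`/`convert` call pattern are taken from the README; this is a hedged sketch, not a definitive test): FP32 `Linear`/`LSTM` children should come back as their dynamic `nnqd` counterparts once `swap_child_modules` has run.

```python
import torch
import torch.nn as nn
import torch.nn.quantized.dynamic as nnqd
import intel_extension_for_pytorch as ipex  # noqa: F401
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig

dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

model = nn.Sequential(nn.Linear(4, 4)).eval()
example_inputs = torch.randn(2, 4)

prepared = prepare(model, dynamic_qconfig, example_inputs=example_inputs, inplace=False)
converted = convert(prepared)
# After convert, swap_child_modules should have replaced the FP32 Linear with nnqd.Linear.
print(any(isinstance(m, nnqd.Linear) for m in converted.modules()))  # expected: True
```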
6 changes: 4 additions & 2 deletions intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
@@ -455,7 +455,9 @@ bool checkQuantization(Block* block) {

if (node->kind() == Symbol::aten("quantize_per_tensor") ||
node->kind() == Symbol::aten("dequantize") ||
node->kind() == Symbol::aten("quantize_per_channel")) {
node->kind() == Symbol::aten("quantize_per_channel") ||
node->kind() == Symbol::aten("quantized_lstm") ||
node->kind() == Symbol::fromQualString("quantized::linear_dynamic")) {
return true;
}
}
@@ -476,11 +478,11 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
// remove BailOut and BailoutTemplate
RemoveBailOutNodesAndSpecializeTypes(graph->block());
RemoveBailoutTemplateNodes(graph->block());

// LLGA fusion pass for int8
GRAPH_DUMP(
"After RemoveProfileNodesAndSpecializeTypes. Before LLGA fusion pass",
graph);

if (isQuantized(graph) || torch_ipex::autocast::is_llga_fp32_bf16_enabled()) {
RemoveRedundantAliases(graph);
fuser::onednn::fuseGraph(graph);
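For reference (stock PyTorch, not the IPEX path): eager-mode dynamic quantization emits the same `quantized::linear_dynamic` node kind when traced, which is one of the node kinds the extended check above now treats as a quantized graph.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Stock-PyTorch illustration of the node kind matched above.
m = nn.Sequential(nn.Linear(4, 4)).eval()
m_int8 = quantize_dynamic(m, {nn.Linear}, dtype=torch.qint8)
traced = torch.jit.trace(m_int8, torch.randn(2, 4))
print("quantized::linear_dynamic" in str(traced.graph))  # True
```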
41 changes: 39 additions & 2 deletions tests/cpu/test_ao_jit_ipex_quantization.py
@@ -15,7 +15,8 @@
from torch.testing._internal.common_utils import TEST_SCIPY, TemporaryFileName

import intel_extension_for_pytorch as ipex
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, QConfig
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, \
QConfig, PlaceholderObserver

default_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)

@@ -34,6 +35,9 @@
weight = default_weight_observer),
]

dynamic_qconfig = QConfig(
activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

class TestIpexOps(JitLlgaTestCase):
def test_adaptive_avg_pool2d(self):
@@ -304,6 +308,39 @@ def forward(self, x):
graph, _, _ = self.prepareModel(m, [x])
FileCheck().check_not("aten::mul_").check("aten::mul").run(graph)

class TestDynamicQuantization(JitLlgaTestCase):
def test_linear_dynamic(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.linear(x)
return x

m = M().eval()
x = torch.randn(1, 3)
graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=dynamic_qconfig)
FileCheck().check_not("aten::linear").check("quantized::linear_dynamic").run(graph)

def test_lstm_dynamic(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.lstm = torch.nn.LSTM(10, 20, 2)

def forward(self, x, hx, cx):
x, h_xs = self.lstm(x, (hx, cx))
return x, h_xs

m = M().eval()
x = torch.randn(5, 3, 10)
h = torch.randn(2, 3, 20)
c = torch.randn(2, 3, 20)
graph = self.checkQuantizeTrace(m, [x, h, c], atol=2e-1, qconfig=dynamic_qconfig)
FileCheck().check_not("aten::lstm").check("aten::quantized_lstm").run(graph)


if __name__ == '__main__':
run_tests()
