Migrate static quant ipex backend to 3.x API (#1596)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
Signed-off-by: chensuyue <suyue.chen@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
violetch24 authored Feb 7, 2024
1 parent 07f940c commit 191383e
Showing 14 changed files with 1,015 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -22,7 +22,7 @@ apt-get install -y --no-install-recommends --fix-missing \
pip install -r /neural-compressor/requirements.txt
pip install cmake

pip install torch==1.12.0 \
pip install torch \
horovod \
google \
autograd \
26 changes: 15 additions & 11 deletions neural_compressor/adaptor/pytorch.py
@@ -397,12 +397,12 @@ def _cfgs_to_fx_cfgs(op_cfgs, observer_type="post_training_static_quant"):
for key, value in op_cfgs.items():
if key == "default_qconfig":
if version.release >= Version("1.13.0").release: # pragma: no cover
fx_op_cfgs.set_global(value)
fx_op_cfgs.set_global(value) # pylint: disable=E1101
else:
fx_op_cfgs[""] = value
continue
if version.release >= Version("1.13.0").release: # pragma: no cover
fx_op_cfgs.set_module_name(key, value)
fx_op_cfgs.set_module_name(key, value) # pylint: disable=E1101
else:
op_tuple = (key, value)
op_tuple_cfg_list.append(op_tuple)
@@ -413,7 +413,7 @@ def _cfgs_to_fx_cfgs(op_cfgs, observer_type="post_training_static_quant"):
from torch.ao.quantization import get_default_qconfig_mapping

for name, q_config in get_default_qconfig_mapping().to_dict()["object_type"]:
fx_op_cfgs.set_object_type(name, q_config)
fx_op_cfgs.set_object_type(name, q_config) # pylint: disable=E1101

return fx_op_cfgs
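
For context, the branches above toggle between two representations of the same per-op configuration: on torch >= 1.13 a QConfigMapping is built through its setter methods, while older releases consume a plain qconfig dict. A minimal standalone sketch, not part of this patch (the module name "conv1" is only a placeholder):

from torch.ao.quantization import QConfigMapping, get_default_qconfig, get_default_qconfig_mapping

qconfig = get_default_qconfig("fbgemm")

# torch >= 1.13: per-op configuration goes through QConfigMapping setters
fx_op_cfgs = QConfigMapping()
fx_op_cfgs.set_global(qconfig)                # "default_qconfig" case
fx_op_cfgs.set_module_name("conv1", qconfig)  # per-op case, keyed by module name
for name, q_config in get_default_qconfig_mapping().to_dict()["object_type"]:
    fx_op_cfgs.set_object_type(name, q_config)

# torch < 1.13: the same information is carried in a plain dict
fx_op_cfgs_legacy = {"": qconfig, "module_name": [("conv1", qconfig)]}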

@@ -3619,7 +3619,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
prepare_custom_config=self.prepare_custom_config_dict,
)
else:
q_model._model = prepare_qat_fx(
q_model._model = prepare_qat_fx( # pylint: disable=E1120,E1123
q_model._model, self.fx_op_cfgs, prepare_custom_config_dict=self.prepare_custom_config_dict
)
else:
@@ -3651,7 +3651,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
prepare_custom_config=self.prepare_custom_config_dict,
)
else:
q_model._model = prepare_fx(
q_model._model = prepare_fx( # pylint: disable=E1120,E1123
q_model._model, self.fx_op_cfgs, prepare_custom_config_dict=self.prepare_custom_config_dict
)
else:
@@ -3681,7 +3681,9 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
# pylint: disable=E1123
q_model._model = convert_fx(q_model._model, convert_custom_config=self.convert_custom_config_dict)
else:
q_model._model = convert_fx(q_model._model, convert_custom_config_dict=self.convert_custom_config_dict)
q_model._model = convert_fx( # pylint: disable=E1123
q_model._model, convert_custom_config_dict=self.convert_custom_config_dict
)
torch_utils.util.append_attr(q_model._model, tmp_model)
del tmp_model
gc.collect()
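
This hunk reflects the keyword-argument rename in the FX quantization API: torch >= 1.13 expects convert_custom_config (and prepare_custom_config), while older releases use the *_dict variants, hence the version gates and pylint suppressions throughout this file. A hedged standalone sketch of such a gate, assuming torch >= 1.10 so the torch.ao import path exists (custom_config is a placeholder for the adaptor's convert_custom_config_dict):

import torch
from packaging.version import Version
from torch.ao.quantization.quantize_fx import convert_fx

def convert_fx_compat(graph_module, custom_config=None):
    # Newer torch accepts convert_custom_config; older torch expects the
    # deprecated convert_custom_config_dict keyword instead.
    if Version(torch.__version__).release >= Version("1.13.0").release:
        return convert_fx(graph_module, convert_custom_config=custom_config)
    return convert_fx(graph_module, convert_custom_config_dict=custom_config)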
@@ -3830,7 +3832,7 @@ def _pre_hook_for_qat(self, dataloader=None):
),
)
else:
self.model._model = prepare_qat_fx(
self.model._model = prepare_qat_fx( # pylint: disable=E1120,E1123
self.model._model,
fx_op_cfgs,
prepare_custom_config_dict=(
@@ -3877,7 +3879,7 @@ def _post_hook_for_qat(self):
),
)
else:
self.model._model = convert_fx(
self.model._model = convert_fx( # pylint: disable=E1123
self.model._model,
convert_custom_config_dict=(
self.model.kwargs.get("convert_custom_config_dict", None)
@@ -4331,15 +4333,15 @@ def prepare_sub_graph(
# pragma: no cover
if is_qat:
module_pre = (
prepare_qat_fx(tmp_module, fx_sub_op_cfgs)
prepare_qat_fx(tmp_module, fx_sub_op_cfgs) # pylint: disable=E1120
if version <= Version("1.12.1")
else prepare_qat_fx(tmp_module, fx_sub_op_cfgs, example_inputs=example_inputs)
)
# pylint: disable=E1123
# pragma: no cover
else:
module_pre = (
prepare_fx(tmp_module, fx_sub_op_cfgs)
prepare_fx(tmp_module, fx_sub_op_cfgs) # pylint: disable=E1120
if version <= Version("1.12.1")
else prepare_fx(tmp_module, fx_sub_op_cfgs, example_inputs=example_inputs)
)
@@ -4433,7 +4435,9 @@ def fuse_fx_model(self, model, is_qat):
fused_model = _fuse_fx(graph_module, is_qat, fuse_custom_config=prepare_custom_config_dict)
elif self.version.release >= Version("1.11.0").release: # pragma: no cover
# pylint: disable=E1124
fused_model = _fuse_fx(graph_module, is_qat, fuse_custom_config_dict=prepare_custom_config_dict)
fused_model = _fuse_fx( # pylint: disable=E1123
graph_module, is_qat, fuse_custom_config_dict=prepare_custom_config_dict
)
else:
fused_model = _fuse_fx(graph_module, prepare_custom_config_dict)
except:
6 changes: 3 additions & 3 deletions neural_compressor/adaptor/torch_utils/util.py
@@ -821,7 +821,7 @@ def output_hook(self, input, output):
if adaptor.version.release >= Version("1.13.0").release: # pragma: no cover
tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
else:
tmp_model = prepare_fx(
tmp_model = prepare_fx( # pylint: disable=E1120
tmp_model,
fx_op_cfgs,
)
@@ -877,7 +877,7 @@ def output_hook(self, input, output):
if adaptor.version.release >= Version("1.13.0").release: # pragma: no cover
tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
else:
tmp_model = prepare_fx(
tmp_model = prepare_fx( # pylint: disable=E1120
tmp_model,
fx_op_cfgs,
)
@@ -958,7 +958,7 @@ def output_hook(self, input, output):
if adaptor.version.release >= Version("1.13.0").release: # pragma: no cover
tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
else:
tmp_model = prepare_fx(
tmp_model = prepare_fx( # pylint: disable=E1120
tmp_model,
fx_op_cfgs,
)
17 changes: 17 additions & 0 deletions neural_compressor/torch/algorithms/static_quant/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2024 Intel Corporation

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utility import *
from .static_quant import static_quantize
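
With this __init__.py in place, the new backend entry point is importable from the package path shown in the diff, e.g.:

from neural_compressor.torch.algorithms.static_quant import static_quantize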
130 changes: 130 additions & 0 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -0,0 +1,130 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from neural_compressor.torch.utils import get_ipex_version

try:
    import intel_extension_for_pytorch as ipex
except:
    assert False, "Please install IPEX for static quantization."

import torch
from packaging.version import Version

from .utility import (
    cfg_to_qconfig,
    dump_model_op_stats,
    get_quantizable_ops_recursively,
    ipex_config_path,
    simple_inference,
)

ipex_ver = get_ipex_version()


def static_quantize(model, tune_cfg, run_fn, example_inputs, inplace=True):
    """Execute the quantization process on the specified model.

    Args:
        model: a float model to be quantized.
        tune_cfg: quantization config for ops.
        run_fn: a calibration function for calibrating the model.
        example_inputs: used to trace torch model.
        inplace: whether to carry out model transformations in-place.

    Returns:
        A quantized model.
    """
    model.eval()

    if ipex_ver.release >= Version("1.12.0").release:
        # Checking save_qconf_summary is a workaround for an IPEX bug:
        # sometimes the prepared model from get_op_capablitiy loses this attribute.
        if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
            from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

            if ipex_ver.release >= Version("2.1").release:
                static_qconfig = ipex.quantization.default_static_qconfig_mapping
            else:
                static_qconfig = QConfig(
                    activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
                    weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
                )
            if isinstance(example_inputs, dict):
                model = ipex.quantization.prepare(
                    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
                )
            else:
                model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)

        model.load_qconf_summary(qconf_summary=ipex_config_path)
        run_fn(model)
        model.save_qconf_summary(qconf_summary=ipex_config_path)
        model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

    else:  # pragma: no cover
        # for IPEX version < 1.12
        _, cfgs, default_cfgs, fuse_ops = get_quantizable_ops_recursively(model, example_inputs)
        qscheme = cfg_to_qconfig(tune_cfg, cfgs, default_cfgs, fuse_ops)
        ipex_conf = ipex.quantization.QuantConf(
            configure_file=ipex_config_path, qscheme=qscheme
        )  # pylint: disable=E1101
        run_fn(model)
        ipex_conf.save(ipex_config_path)
        ipex_conf = ipex.quantization.QuantConf(ipex_config_path)  # pylint: disable=E1101
        model = ipex.quantization.convert(model, ipex_conf, example_inputs, inplace=True)  # pylint: disable=E1121

    with open(ipex_config_path, "r") as f:
        model.tune_cfg = json.load(f)
    model.ipex_config_path = ipex_config_path
    if ipex_ver.release >= Version("1.12.0").release:
        dump_model_op_stats(tune_cfg)
    return model


def _ipex_post_quant_process(model, example_inputs, inplace=False):
    """Convert the prepared model to a jit model.

    Args:
        model: a prepared model.
        example_inputs: used to trace torch model.
        inplace: whether to carry out model transformations in-place.

    Returns:
        A converted jit model.
    """
    model = ipex.quantization.convert(model, inplace=inplace)
    with torch.no_grad():
        try:
            if isinstance(example_inputs, dict):
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
            else:
                model = torch.jit.trace(model, example_inputs)
            model = torch.jit.freeze(model.eval())
        except:
            if isinstance(example_inputs, dict):
                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
            else:
                model = torch.jit.trace(model, example_inputs, strict=False)
            model = torch.jit.freeze(model.eval())
        # After freezing, run the model once to warm up the profiling graph executor and insert prim::profile nodes.
        # On the 2nd run the LLGA pass is triggered and the model is turned into an int8 model:
        # prim::profile nodes are removed and LlgaFusionGroup nodes appear in the graph.
        simple_inference(model, example_inputs, iterations=2)
    return model
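
Taken together, static_quantize wraps the standard IPEX static-quantization flow: prepare with a qconfig, calibrate, convert, JIT-trace, freeze, and warm up. A minimal standalone sketch of that flow under the newer code path (toy model and calibration loop are placeholders; assumes intel_extension_for_pytorch >= 2.1 so that default_static_qconfig_mapping is available, as in the module above):

import torch
import intel_extension_for_pytorch as ipex

# Toy model and example inputs (placeholders, not from the patch)
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)).eval()
example_inputs = torch.randn(2, 8)

# 1. Prepare: attach observers according to the default static qconfig mapping
qconfig = ipex.quantization.default_static_qconfig_mapping
prepared = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=False)

# 2. Calibrate: run representative data through the prepared model
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(2, 8))

# 3. Convert, trace, and freeze
converted = ipex.quantization.convert(prepared)
with torch.no_grad():
    traced = torch.jit.trace(converted, example_inputs)
    traced = torch.jit.freeze(traced.eval())

    # 4. Warm up: the first run inserts prim::profile nodes; the second run
    #    triggers the LLGA fusion pass that produces the final int8 graph
    for _ in range(2):
        traced(example_inputs)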