Expose public APIs for loading/saving quantsim for backwards-compatibility

Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu committed Dec 18, 2024
1 parent d8c67af commit 533e583
Showing 5 changed files with 171 additions and 108 deletions.
69 changes: 67 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py
@@ -43,6 +43,7 @@
from collections import OrderedDict, defaultdict
import json
import warnings
import pickle
from typing import (
Callable,
List,
@@ -66,6 +67,7 @@
from aimet_common.utils import AimetLogger, save_json_yaml, log_with_error_and_assert_if_false
from aimet_common.defs import QuantScheme, QuantizationDataType, SupportedKernelsAction, QuantDtypeBwInfo
from aimet_common.quantsim import validate_quantsim_inputs, extract_global_quantizer_args, VALID_ENCODING_VERSIONS
from aimet_common.quant_utils import get_conv_accum_bounds
from aimet_common.utils import deprecated, _red
from aimet_common import quantsim

@@ -1610,7 +1612,7 @@ def _load_encodings_impl(self, encodings: Mapping,
self._set_activation_encodings(activation_encodings,
strict, partial, requires_grad, allow_overwrite)

@deprecated(f"Use {load_encodings.__qualname__} instead.")
@deprecated("Use QuantizationSimModel.load_encodings instead.")
def load_and_freeze_encodings(self, encoding_path: str, ignore_when_quantizer_disabled: bool = False):
"""
Functionality to set encodings (both activation and parameter) as per the given encodings JSON file and
Expand Down Expand Up @@ -1690,7 +1692,7 @@ def _set_activation_encodings(self,
raise RuntimeError(f"Encoding import failed for module: {module_name}.\n{str(e)}") from e


@deprecated(f"Use {load_encodings.__qualname__} instead.")
@deprecated("Use QuantizationSimModel.load_encodings instead.")
def set_and_freeze_param_encodings(self, encoding_path: str):
"""
Set and freeze parameter encodings from encodings JSON file.
@@ -1728,3 +1730,66 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module
with utils.in_eval_mode(module), torch.no_grad():
marker_layer = torch.jit.trace(CustomMarker(module, name, True), dummy_input)
self._module_marker_map[name] = marker_layer




def save_checkpoint(quant_sim_model: _QuantizationSimModelInterface, file_path: str):
"""
This API provides a way for the user to save a checkpoint of the quantized model which can
be loaded at a later point to continue fine-tuning e.g.
See also load_checkpoint()
:param quant_sim_model: QuantizationSimModel to save checkpoint for
:param file_path: Path to the file where you want to save the checkpoint
:return: None
"""
with open(file_path, 'wb') as file:
pickle.dump(quant_sim_model, file)


def load_checkpoint(file_path: str) -> _QuantizationSimModelInterface:
"""
Load the quantized model
:param file_path: Path to the file where you want to save the checkpoint
:return: A new instance of the QuantizationSimModel created after loading the checkpoint
"""
with open(file_path, 'rb') as file:
sim = pickle.load(file)
return sim
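
For illustration, a minimal sketch of the checkpoint round-trip using the two helpers above; the toy model, dummy input, and file path are placeholders, and constructor arguments beyond the model and dummy input are omitted:

import torch
from aimet_torch.quantsim import QuantizationSimModel, save_checkpoint, load_checkpoint

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
dummy_input = torch.randn(1, 3, 32, 32)
sim = QuantizationSimModel(model, dummy_input)       # illustrative construction

save_checkpoint(sim, '/tmp/quantsim.pkl')            # pickles the entire sim object
restored = load_checkpoint('/tmp/quantsim.pkl')      # a new sim instance unpickled from disk
assert isinstance(restored, QuantizationSimModel)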


@deprecated("check_accumulator_overflow API will be removed in the future releases.")
def check_accumulator_overflow(model: torch.nn.Module, quant_bw: int, accum_bw: int):
"""
Checks for any potential for accumulator overflow across all the layers of the given model
:param model: Model
:param quant_bw: Bitwidth the layers are quantized at
:param accum_bw: Bitwidth of the accumulator
:return: Name of the layer with the most accumulator range used and range used
"""

most_accum_range_used = 0
most_accum_range_used_layer = None

for layer_name, layer in model.named_modules():

if isinstance(layer, torch.nn.Conv2d):
was_accum_range_exceeded, accum_range_used = get_conv_accum_bounds(layer.weight.detach().numpy(),
quant_bw, accum_bw)
if accum_range_used > most_accum_range_used:
most_accum_range_used = accum_range_used
most_accum_range_used_layer = layer_name

if was_accum_range_exceeded:
logger.info('Possible accumulator overflow for layer: %s', layer_name)

if most_accum_range_used < 1:
logger.info('No overflow detected. Layer %s had the most accumulator range used: %f%%',
most_accum_range_used_layer, most_accum_range_used * 100)
else:
logger.info('Overflow detected. Layer %s had the most accumulator range used: %f%%',
most_accum_range_used_layer, most_accum_range_used * 100)

return most_accum_range_used_layer, most_accum_range_used
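
A minimal sketch of the deprecated overflow check above, assuming 8-bit quantized weights and a 32-bit accumulator; the toy model is a placeholder:

import torch
from aimet_torch.quantsim import check_accumulator_overflow

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU(), torch.nn.Conv2d(16, 16, 3))
# Returns the Conv2d layer using the largest share of the accumulator range, and that share
layer_name, range_used = check_accumulator_overflow(model, quant_bw=8, accum_bw=32)
print(layer_name, range_used)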
15 changes: 15 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -42,6 +42,11 @@
QuantizationSimModel,
QuantParams,
ExportableQuantModule,
save_checkpoint,
load_checkpoint,
check_accumulator_overflow,
load_encodings_to_sim,
compute_encodings_for_sims,
)

from .utils import _warn_replaced_in_v2
@@ -56,13 +61,23 @@
QuantizationSimModel,
QuantParams,
ExportableQuantModule,
save_checkpoint,
load_checkpoint,
check_accumulator_overflow,
load_encodings_to_sim,
compute_encodings_for_sims,
)


__all__ = [
'QuantizationSimModel',
'QuantParams',
'ExportableQuantModule',
'save_checkpoint',
'load_checkpoint',
'check_accumulator_overflow',
'load_encodings_to_sim',
'compute_encodings_for_sims',
]

undefined = set(__all__) - set(globals())
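
With these re-exports, the previously-public helpers resolve through the top-level aimet_torch.quantsim module again, so existing imports like the following keep working (a sketch of the restored surface):

from aimet_torch.quantsim import (
    QuantizationSimModel,
    save_checkpoint,
    load_checkpoint,
    check_accumulator_overflow,
    load_encodings_to_sim,
    compute_encodings_for_sims,
)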
75 changes: 10 additions & 65 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/quantsim.py
@@ -40,13 +40,11 @@
import os
import io
import copy
import pickle
from typing import Tuple, List, Union, Dict, Callable, Optional, Any
import torch

from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_common.quant_utils import get_conv_accum_bounds
from aimet_common.utils import deprecated

from aimet_torch.v1.qc_quantize_op import QcQuantizeStandAloneBase, QcQuantizeWrapper, QcQuantizeOpMode, \
@@ -65,12 +63,20 @@
unquantizable_modules,
QuantParams,
ExportableQuantModule,
save_checkpoint,
load_checkpoint,
check_accumulator_overflow,
)

__all__ = [
'QuantizationSimModel',
'QuantParams',
'ExportableQuantModule',
'save_checkpoint',
'load_checkpoint',
'check_accumulator_overflow',
'load_encodings_to_sim',
'compute_encodings_for_sims',
]


@@ -650,69 +656,8 @@ def _validate_torchquantizer(quant_sim_model)



def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
"""
This API provides a way for the user to save a checkpoint of the quantized model which can
be loaded at a later point to continue fine-tuning e.g.
See also load_checkpoint()
:param quant_sim_model: QuantizationSimModel to save checkpoint for
:param file_path: Path to the file where you want to save the checkpoint
:return: None
"""
with open(file_path, 'wb') as file:
pickle.dump(quant_sim_model, file)


def load_checkpoint(file_path: str) -> QuantizationSimModel:
"""
Load the quantized model
:param file_path: Path to the file where you want to save the checkpoint
:return: A new instance of the QuantizationSimModel created after loading the checkpoint
"""
with open(file_path, 'rb') as file:
sim = pickle.load(file)
return sim


@deprecated("check_accumulator_overflow API will be removed in the future releases.")
def check_accumulator_overflow(model: torch.nn.Module, quant_bw: int, accum_bw: int):
"""
Checks for any potential for accumulator overflow across all the layers of the given model
:param model: Model
:param quant_bw: Bitwidth the layers are quantized at
:param accum_bw: Bitwidth of the accumulator
:return: Name of the layer with the most accumulator range used and range used
"""

most_accum_range_used = 0
most_accum_range_used_layer = None

for layer_name, layer in model.named_modules():

if isinstance(layer, torch.nn.Conv2d):
was_accum_range_exceeded, accum_range_used = get_conv_accum_bounds(layer.weight.detach().numpy(),
quant_bw, accum_bw)
if accum_range_used > most_accum_range_used:
most_accum_range_used = accum_range_used
most_accum_range_used_layer = layer_name

if was_accum_range_exceeded:
logger.info('Possible accumulator overflow for layer: %s', layer_name)

if most_accum_range_used < 1:
logger.info('No overflow detected. Layer %s had the most accumulator range used: %f%%',
most_accum_range_used_layer, most_accum_range_used * 100)
else:
logger.info('Overflow detected. Layer %s had the most accumulator range used: %f%%',
most_accum_range_used_layer, most_accum_range_used * 100)

return most_accum_range_used_layer, most_accum_range_used


@deprecated(f"Use {QuantizationSimModel.load_encodings.__qualname__} instead.")
def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, pytorch_encoding_path: str):
@deprecated("Use QuantizationSimModel.load_encodings instead.")
def load_encodings_to_sim(quant_sim_model: _QuantizationSimModelBase, pytorch_encoding_path: str):
"""
Loads the saved encodings to quant sim model. The encoding filename to load should end in _torch.encodings,
generated as part of quantsim export.
@@ -51,6 +51,9 @@
unquantizable_modules,
QuantParams,
ExportableQuantModule,
save_checkpoint,
load_checkpoint,
check_accumulator_overflow,
)
from aimet_torch.v2 import nn as aimet_nn
from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizationMixin
@@ -68,6 +71,11 @@
'QuantizationSimModel',
'QuantParams',
'ExportableQuantModule',
'save_checkpoint',
'load_checkpoint',
'check_accumulator_overflow',
'load_encodings_to_sim',
'compute_encodings_for_sims',
]

unquantizable_modules = (QuantizerBase, *unquantizable_modules)
@@ -522,3 +530,50 @@ def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclu
# Recursively call children modules if present
if not utils.is_leaf_module(module):
cls._remove_quantization_wrappers(module, list_of_modules_to_exclude)


@deprecated("Use QuantizationSimModel.load_encodings instead.")
def load_encodings_to_sim(quant_sim_model: _QuantizationSimModelBase, pytorch_encoding_path: str):
"""
Loads the saved encodings to quant sim model. The encoding filename to load should end in _torch.encodings,
generated as part of quantsim export.
:param quant_sim_model: Quantized model to load encodings for. Note: The model configuration should be the same as
when encodings were exported.
:param pytorch_encoding_path: Path of the encodings file to load.
"""
quant_sim_model.load_encodings(pytorch_encoding_path,
strict=True,
partial=False,
requires_grad=None,
allow_overwrite=None)
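
A minimal sketch of the deprecated wrapper above, assuming sim is an existing QuantizationSimModel and the encodings file was produced by a previous quantsim export (the path is illustrative):

from aimet_torch.quantsim import load_encodings_to_sim

# Equivalent to:
# sim.load_encodings('./export/model_torch.encodings', strict=True, partial=False,
#                    requires_grad=None, allow_overwrite=None)
load_encodings_to_sim(sim, './export/model_torch.encodings')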


def compute_encodings_for_sims(sim_list: Sequence[QuantizationSimModel], forward_pass_callback: Callable,
forward_pass_callback_args: Any):
"""
Compute encodings for a list of QuantSims.
:param sim_list: List of QuantSims to compute encodings for.
:param forward_pass_callback: A callback function that simply runs forward passes on the models. This callback
function should use representative data for the forward pass, so the calculated encodings work for all
data samples. This callback internally chooses the number of data samples it wants to use for calculating
encodings.
The callback expects exactly two inputs:
- List of models which are involved in the forward pass. The models are taken directly from calling
sim.model for each sim in sim_list, passed in the same order in which the sims appear in sim_list.
- Forward pass callback args
:param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
the user to determine the type of this parameter. E.g. could be simply an integer representing the number
of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
If set to None, forward_pass_callback will be invoked with no parameters.
"""
ctx_managers = [torch.no_grad()]
for sim in sim_list:
ctx_managers.append(utils.in_eval_mode(sim.model))
ctx_managers.append(aimet_nn.compute_encodings(sim.model))

with contextlib.ExitStack() as stack:
for mgr in ctx_managers:
stack.enter_context(mgr)
_ = forward_pass_callback([sim.model for sim in sim_list], forward_pass_callback_args)
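
A minimal sketch of calibrating two sims together with the helper above; sim_a, sim_b, the input shape, and the batch count are placeholders for user-supplied calibration setup:

import torch
from aimet_torch.quantsim import compute_encodings_for_sims

def calibrate(models, num_batches):
    # Run representative data through every model; encodings are gathered inside the context managers
    for _ in range(num_batches):
        batch = torch.randn(1, 3, 32, 32)
        for m in models:
            m(batch)

compute_encodings_for_sims([sim_a, sim_b], calibrate, forward_pass_callback_args=4)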