refactor: articulate imports (#3657)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Dec 18, 2024
1 parent 86daca1 commit e213b1f
Showing 22 changed files with 240 additions and 75 deletions.
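
Editor's note: the same three-step change repeats across all 22 files. Each wildcard re-export (from ... import *) is replaced with explicit imports, the module declares its public surface in __all__, and an import-time assertion verifies that every name promised in __all__ was actually bound. Below is a minimal, self-contained sketch of that pattern; the module ("math") and symbols (sqrt, pi) are stand-ins, not the actual aimet_torch layout.

# articulated_imports.py -- sketch only; names are illustrative.

# Before: a wildcard re-export left the public surface implicit.
#     from math import *    # pylint: disable=wildcard-import

# After: each re-exported name is spelled out explicitly ...
from math import sqrt, pi

__all__ = [
    'sqrt',
    'pi',
]

# ... and a guard fails at import time if a name promised in __all__
# was never actually bound in this module's namespace.
undefined = set(__all__) - set(globals())
assert not undefined, \
    f"The following attributes are undefined: {list(undefined)}"
del undefined
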
@@ -59,8 +59,10 @@
from tqdm import tqdm

from aimet_torch import utils
from aimet_torch._base.adaround.adaround_weight import AdaroundParameters
from aimet_torch.cross_layer_equalization import equalize_model
from aimet_torch._base.quantsim import _QuantizationSimModelInterface
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_torch.utils import in_eval_mode
from aimet_torch.onnx_utils import OnnxExportApiArgs
from aimet_torch.model_preparer import prepare_model
@@ -830,7 +832,7 @@ def eval_callback_wrapper(model: torch.nn.Module, *args, **kwargs) -> float:
batch_size = self.data_loader.batch_size or 1
num_batches = math.ceil(num_samples / batch_size)
num_batches = min(num_batches, len(self.data_loader))
self.adaround_params = self._get_adaround_parameters(self.data_loader, num_batches)
self.adaround_params = AdaroundParameters(self.data_loader, num_batches)

self._export_kwargs = dict(
onnx_export_args=OnnxExportApiArgs(),
@@ -857,12 +859,6 @@ def eval_callback_wrapper(model: torch.nn.Module, *args, **kwargs) -> float:
def _get_adaround():
""" returns AdaRound """

@staticmethod
@abc.abstractmethod
def _get_adaround_parameters(data_loader, num_batches):
""" Returns AdaroundParameters(data_loader, num_batches) """


def _evaluate_model_performance(self, model) -> float:
"""
Evaluate the model performance.
@@ -1042,11 +1038,6 @@ def _create_quantsim_and_encodings( # pylint: disable=too-many-arguments, too-ma
def _get_quantsim(model, dummy_input, **kwargs):
""" Returns QuantizationSimModel(model, dummy_input, **kwargs) """

@staticmethod
@abc.abstractmethod
def _fold_all_batch_norms(*args, **kwargs):
...

@abc.abstractmethod
def _configure_quantsim(self, # pylint: disable=too-many-arguments
sim,
@@ -1093,7 +1084,7 @@ def _apply_batchnorm_folding(self, model: torch.nn.Module)\
:return: Output model and folded pairs.
"""
model = copy.deepcopy(model)
folded_pairs = self._fold_all_batch_norms(model, None, self.dummy_input)
folded_pairs = fold_all_batch_norms(model, None, self.dummy_input)
return model, folded_pairs

@cache.mark("cle")
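
Editor's note: the hunks above also remove two abstract static methods (_get_adaround_parameters and _fold_all_batch_norms) whose only job was to forward to a single shared implementation; the base class now imports and calls that implementation directly. A self-contained sketch of the shape of this refactor, with hypothetical names and a stubbed fold function:

import abc
import copy


def fold_all_batch_norms(model, input_shapes, dummy_input):
    """Stub standing in for the shared implementation both API versions use."""
    return [('bn', 'conv')]


class AutoQuantBase(abc.ABC):
    # Before: every subclass overrode a trivial forwarding hook:
    #     @staticmethod
    #     @abc.abstractmethod
    #     def _fold_all_batch_norms(*args, **kwargs): ...
    # After: the base class calls the one implementation directly.
    def __init__(self, dummy_input):
        self.dummy_input = dummy_input

    def _apply_batchnorm_folding(self, model):
        model = copy.deepcopy(model)
        folded_pairs = fold_all_batch_norms(model, None, self.dummy_input)
        return model, folded_pairs


class AutoQuant(AutoQuantBase):
    pass  # no forwarding boilerplate left to implement


print(AutoQuant(dummy_input=None)._apply_batchnorm_folding(model={'w': 1}))
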
@@ -38,7 +38,11 @@
from ..utils import _get_default_api

if _get_default_api() == "v1":
from ..v1.adaround.adaround_weight import * # pylint: disable=wildcard-import, unused-wildcard-import
from ..v1.adaround.adaround_weight import (
Adaround,
AdaroundParameters,
AdaroundSupportedModules,
)

from ..utils import _warn_replaced_in_v2
from ..v1.adaround import adaround_weight as _v1_api
@@ -48,4 +52,20 @@
v2_new_api=_v2_api.__name__,
v1_legacy_api=_v1_api.__name__)
else:
from ..v2.adaround.adaround_weight import * # pylint: disable=wildcard-import, unused-wildcard-import
from ..v2.adaround.adaround_weight import (
Adaround,
AdaroundParameters,
AdaroundSupportedModules,
)


__all__ = [
'Adaround',
'AdaroundParameters',
'AdaroundSupportedModules',
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
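
Editor's note: the guard is what makes the articulated form safer than the wildcard it replaces. With two version-gated branches (v1 and v2), either branch can drift and silently stop providing a promised name. A hypothetical, self-contained demonstration of the failure mode it catches:

# If a branch forgets to bind a name listed in __all__, the module now
# fails loudly at import time instead of raising AttributeError at first
# use. AdaroundParameters is intentionally left unbound here.
__all__ = ['Adaround', 'AdaroundParameters']

Adaround = object()

undefined = set(__all__) - set(globals())
assert not undefined, \
    f"The following attributes are undefined: {list(undefined)}"
# -> AssertionError: The following attributes are undefined: ['AdaroundParameters']
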
15 changes: 13 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/auto_quant.py
@@ -38,7 +38,7 @@
from .utils import _get_default_api

if _get_default_api() == "v1":
from .v1.auto_quant import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v1.auto_quant import AutoQuant, AutoQuantWithAutoMixedPrecision

from .utils import _warn_replaced_in_v2
from .v1 import auto_quant as _v1_api
@@ -48,4 +48,15 @@
v2_new_api=_v2_api.__name__,
v1_legacy_api=_v1_api.__name__)
else:
from .v2.auto_quant import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v2.auto_quant import AutoQuant, AutoQuantWithAutoMixedPrecision


__all__ = [
'AutoQuant',
'AutoQuantWithAutoMixedPrecision',
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
34 changes: 31 additions & 3 deletions TrainingExtensions/torch/src/python/aimet_torch/batch_norm_fold.py
@@ -35,11 +35,17 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Alias to v1/v2 batch_norm_fold """

from .utils import _get_default_api

if _get_default_api() == "v1":
from .v1.batch_norm_fold import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v1.batch_norm_fold import (
fold_all_batch_norms,
fold_all_batch_norms_to_scale,
fold_given_batch_norms,
_is_valid_bn_fold,
_find_all_batch_norms_to_fold,
find_standalone_batchnorm_ops,
)

from .utils import _warn_replaced_in_v2
from .v1 import batch_norm_fold as _v1_batch_norm_fold
@@ -49,4 +55,26 @@
v2_new_api=_v2_batch_norm_fold.__name__,
v1_legacy_api=_v1_batch_norm_fold.__name__)
else:
from .v2.batch_norm_fold import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v2.batch_norm_fold import (
fold_all_batch_norms,
fold_all_batch_norms_to_scale,
fold_given_batch_norms,
_is_valid_bn_fold,
_find_all_batch_norms_to_fold,
find_standalone_batchnorm_ops,
)


__all__ = [
"fold_all_batch_norms",
"fold_all_batch_norms_to_scale",
"fold_given_batch_norms",
"_is_valid_bn_fold",
"_find_all_batch_norms_to_fold",
"find_standalone_batchnorm_ops",
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
14 changes: 12 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/mixed_precision.py
@@ -39,7 +39,7 @@
from .utils import _get_default_api

if _get_default_api() == "v1":
from .v1.mixed_precision import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v1.mixed_precision import choose_mixed_precision

from .utils import _warn_replaced_in_v2
from .v1 import mixed_precision as _v1_api
@@ -49,4 +49,14 @@
v2_new_api=_v2_api.__name__,
v1_legacy_api=_v1_api.__name__)
else:
from .v2.mixed_precision import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v2.mixed_precision import choose_mixed_precision


__all__ = [
'choose_mixed_precision',
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
14 changes: 12 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/quant_analyzer.py
@@ -38,7 +38,7 @@
from .utils import _get_default_api

if _get_default_api() == "v1":
from .v1.quant_analyzer import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v1.quant_analyzer import QuantAnalyzer

from .utils import _warn_replaced_in_v2
from .v1 import quant_analyzer as _v1_api
@@ -48,4 +48,14 @@
v2_new_api=_v2_api.__name__,
v1_legacy_api=_v1_api.__name__)
else:
from .v2.quant_analyzer import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v2.quant_analyzer import QuantAnalyzer


__all__ = [
'QuantAnalyzer',
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
27 changes: 22 additions & 5 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -34,14 +34,15 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=wildcard-import, unused-wildcard-import, unused-import
""" Alias to v1/v2 quantsim """

from .utils import _get_default_api
from ._base.quantsim import QuantParams

if _get_default_api() == "v1":
from .v1.quantsim import *
from .v1.quantsim import (
QuantizationSimModel,
QuantParams,
ExportableQuantModule,
)

from .utils import _warn_replaced_in_v2
from .v1 import quantsim as _v1_quantsim
@@ -51,4 +52,20 @@
v2_new_api=_v2_quantsim.__name__,
v1_legacy_api=_v1_quantsim.__name__)
else:
from .v2.quantsim import *
from .v2.quantsim import (
QuantizationSimModel,
QuantParams,
ExportableQuantModule,
)


__all__ = [
'QuantizationSimModel',
'QuantParams',
'ExportableQuantModule',
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
30 changes: 28 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/seq_mse.py
@@ -39,7 +39,13 @@
from .utils import _get_default_api

if _get_default_api() == "v1":
from .v1.seq_mse import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v1.seq_mse import (
SequentialMse,
SeqMseParams,
apply_seq_mse,
get_candidates,
optimize_module,
)

from .utils import _warn_replaced_in_v2
from .v1 import seq_mse as _v1_api
@@ -49,4 +55,24 @@
v2_new_api=_v2_api.__name__,
v1_legacy_api=_v1_api.__name__)
else:
from .v2.seq_mse import * # pylint: disable=wildcard-import, unused-wildcard-import
from .v2.seq_mse import (
SequentialMse,
SeqMseParams,
apply_seq_mse,
get_candidates,
optimize_module,
)


__all__ = [
'SequentialMse',
'SeqMseParams',
'apply_seq_mse',
'get_candidates',
'optimize_module',
]

undefined = set(__all__) - set(globals())
assert not undefined, \
f"The following attributes are undefined: {list(undefined)}"
del undefined
@@ -47,7 +47,7 @@

from aimet_torch import utils
from aimet_torch.save_utils import SaveUtils
from aimet_torch._base.adaround.adaround_weight import ( # pylint: disable=unused-import
from aimet_torch._base.adaround.adaround_weight import (
AdaroundBase,
AdaroundParameters,
AdaroundSupportedModules,
@@ -56,6 +56,13 @@
from aimet_torch.v1.qc_quantize_op import StaticGridQuantWrapper, QcQuantizeOpMode
from aimet_torch.v1.adaround.adaround_wrapper import AdaroundWrapper


__all__ = [
'Adaround',
'AdaroundParameters',
'AdaroundSupportedModules',
]

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)


18 changes: 8 additions & 10 deletions TrainingExtensions/torch/src/python/aimet_torch/v1/auto_quant.py
@@ -59,8 +59,8 @@
_MixedPrecisionResult,
ParetoFrontType,
)
from aimet_torch.v1.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms
from aimet_torch.v1.adaround.adaround_weight import Adaround
from aimet_torch._base.adaround.adaround_weight import AdaroundParameters
from aimet_torch.v1.quantsim import QuantizationSimModel
from aimet_torch.utils import get_all_quantizers
from aimet_torch.onnx_utils import OnnxExportApiArgs
@@ -76,6 +76,12 @@
)


__all__ = [
'AutoQuant',
'AutoQuantWithAutoMixedPrecision',
]


_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.AutoQuant)


@@ -142,14 +148,6 @@ def _get_adaround():
""" returns AdaRound """
return Adaround

@staticmethod
def _fold_all_batch_norms(*args, **kwargs):
return fold_all_batch_norms(*args, **kwargs)

@staticmethod
def _get_adaround_parameters(data_loader, num_batches):
return AdaroundParameters(data_loader, num_batches)

@staticmethod
def _get_quantsim(model, dummy_input, **kwargs):
return QuantizationSimModel(model, dummy_input, **kwargs)
@@ -51,6 +51,15 @@
from aimet_torch.v1.tensor_quantizer import LearnedGridTensorQuantizer
from aimet_torch._base.batch_norm_fold import BatchNormFoldBase, _BatchNormFoldingNotSupported

__all__ = [
"fold_all_batch_norms",
"fold_all_batch_norms_to_scale",
"fold_given_batch_norms",
"_is_valid_bn_fold",
"_find_all_batch_norms_to_fold",
"find_standalone_batchnorm_ops",
]

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.BatchNormFolding)


@@ -59,13 +59,20 @@
from aimet_torch.v1.qc_quantize_recurrent import QcQuantizeRecurrent
from aimet_torch.quantsim_config.builder import LazyQuantizeWrapper
from aimet_torch.v1._builder import _V1LazyQuantizeWrapper
from aimet_torch._base.quantsim import ( # pylint: disable=unused-import
from aimet_torch._base.quantsim import (
_QuantizationSimModelBase,
_QuantizedModuleProtocol,
unquantizable_modules,
QuantParams,
ExportableQuantModule,
)

__all__ = [
'QuantizationSimModel',
'QuantParams',
'ExportableQuantModule',
]


logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
