Skip to content

Commit

Permalink
adds Keras callback to save models during QAT (#2558)
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Ernst <quic_ernst@quicinc.com>
  • Loading branch information
quic-ernst authored Nov 9, 2023
1 parent b98f200 commit 0d93665
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def _set_op_mode_parameters(self, op_mode: libpymo.TensorQuantizerOpMode):
if param_quantizer.is_enabled():
param_quantizer.quant_mode = op_mode

def export(self, path, filename_prefix, custom_objects=None):
def export(self, path, filename_prefix, custom_objects=None, convert_to_pb=True):
"""
This method exports out the quant-sim model so it is ready to be run on-target.
Specifically, the following are saved
Expand All @@ -447,7 +447,8 @@ def export(self, path, filename_prefix, custom_objects=None):

# Conversion of saved h5 model to pb model for consumption by SNPE/QNN
try:
convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects)
if convert_to_pb:
convert_h5_model_to_pb_model(f'{model_path}.h5', custom_objects=custom_objects)
except ValueError:
_logger.error("Could not convert h5 to frozen pb. "
"Please call export() again with custom_objects defined.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,20 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
"""Quantizer utility"""
import os.path
from typing import List, Optional, Union

import numpy as np
import tensorflow as tf
from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper

from aimet_tensorflow.keras.quant_sim.tensor_quantizer import ParamPerChannelQuantizer, ParamPerTensorQuantizer, TensorQuantizer
from aimet_common.utils import AimetLogger

from aimet_tensorflow.keras.quant_sim.qc_quantize_wrapper import QcQuantizeWrapper
from aimet_tensorflow.keras.quant_sim.tensor_quantizer import ParamPerChannelQuantizer, ParamPerTensorQuantizer, \
TensorQuantizer
from aimet_tensorflow.keras.quantsim import QuantizationSimModel

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

def get_enabled_param_quantizers(sim: QuantizationSimModel) -> List[TensorQuantizer]:
"""
Expand Down Expand Up @@ -93,6 +98,7 @@ def enable_disable_quantizers(quantizers: List[TensorQuantizer],
for quantizer in quantizers:
quantizer.disable()


# pylint: disable=protected-access
def get_wrappers_weight_quantizer(param_quantizers: Union[List[ParamPerTensorQuantizer], List[ParamPerChannelQuantizer]]) -> \
Union[ParamPerTensorQuantizer, ParamPerChannelQuantizer, List[ParamPerTensorQuantizer], List[ParamPerChannelQuantizer]]:
Expand All @@ -119,6 +125,7 @@ def get_wrappers_weight_quantizer(param_quantizers: Union[List[ParamPerTensorQua

raise AttributeError(f"Unable to find kernel quantizer.")


# pylint: disable=protected-access
def get_wrappers_bias_quantizer(param_quantizers: Union[List[ParamPerTensorQuantizer], List[ParamPerChannelQuantizer]]) -> \
Optional[Union[ParamPerTensorQuantizer, ParamPerChannelQuantizer, List[ParamPerTensorQuantizer], List[ParamPerChannelQuantizer]]]:
Expand Down Expand Up @@ -147,6 +154,7 @@ def get_wrappers_bias_quantizer(param_quantizers: Union[List[ParamPerTensorQuant
return quantizer
return None


def model_contains_only_quantize_wrappers(model: tf.keras.Model) -> bool:
"""
Helper function to determine if a given model only contains quantize wrappers (besides InputLayers).
Expand All @@ -156,3 +164,22 @@ def model_contains_only_quantize_wrappers(model: tf.keras.Model) -> bool:
"""

return np.all(np.vectorize(lambda x: isinstance(x, (tf.keras.layers.InputLayer, QcQuantizeWrapper)))(model.layers))


class SaveModelWithoutQuantsimWrappersCallback(tf.keras.callbacks.Callback):
    """
    Keras callback that exports the QuantSim model at the end of every epoch during QAT.

    Each export is written under ``save_path`` with a filename prefix of
    ``<filename_prefix>_epoch_<epoch>``. The export is requested with
    ``convert_to_pb=False``, so the h5-to-pb conversion step of ``export`` is skipped.
    """

    def __init__(self, sim: QuantizationSimModel, save_path: str, filename_prefix: str, custom_objects: dict = None):
        """
        :param sim: QuantizationSimModel whose ``export`` method is invoked after each epoch
        :param save_path: Directory to export into; stored as an absolute path
        :param filename_prefix: Prefix used for the exported files, before the epoch suffix is appended
        :param custom_objects: Optional custom objects forwarded unchanged to ``sim.export``
        """
        super().__init__()
        self.sim = sim
        # Resolve once up front so later working-directory changes do not move the export target.
        self.save_path = os.path.abspath(save_path)
        self.filename_prefix = filename_prefix
        self.custom_objects = custom_objects

    def on_epoch_end(self, epoch, logs=None):
        """
        Export the sim model, tagging the output files with the epoch index.

        :param epoch: Epoch index supplied by Keras; appended to the filename prefix
        :param logs: Metric results supplied by Keras; unused here
        """
        epoch_prefix = f"{self.filename_prefix}_epoch_{epoch}"
        self.sim.export(self.save_path, epoch_prefix, self.custom_objects, convert_to_pb=False)
        _logger.info("End epoch %s; successfully exported model.", epoch)
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

from aimet_common.defs import QuantScheme, RANGE_LEARNING_SCHEMES
from aimet_tensorflow.examples.test_models import keras_model
from aimet_tensorflow.keras.utils.quantizer_utils import SaveModelWithoutQuantsimWrappersCallback
from aimet_tensorflow.keras.cross_layer_equalization import equalize_model
from aimet_tensorflow.keras.quant_sim.qc_mha_wrapper import QcQuantizableMultiHeadAttention
from aimet_tensorflow.keras.quantsim import QuantizationSimModel
Expand Down Expand Up @@ -420,18 +421,36 @@ def test_qat():
running_dense_output_quantizer_encoding_max = \
tf.keras.backend.get_value(qsim.model.layers[1].output_quantizers[0]._encoding_max)

for i in range(10):
_ = qsim.model.fit(x=rand_inp, y=rand_out, batch_size=1)
ending_weights = [tf.keras.backend.get_value(param) for
param in qsim.model.layers[1]._layer_to_wrap.weights]
new_dense_output_quantizer_encoding_max = \
tf.keras.backend.get_value(qsim.model.layers[1].output_quantizers[0]._encoding_max)
for idx, weight in enumerate(running_weights):
assert not np.array_equal(weight, ending_weights[idx])
assert np.array_equal(new_dense_output_quantizer_encoding_max,
running_dense_output_quantizer_encoding_max)
running_weights = ending_weights
running_dense_output_quantizer_encoding_max = new_dense_output_quantizer_encoding_max
with tempfile.TemporaryDirectory() as tmp_dir:
epochs = 10
save_model_callback = SaveModelWithoutQuantsimWrappersCallback(qsim, tmp_dir, "test_qat")
for i in range(epochs):
_ = qsim.model.fit(x=rand_inp, y=rand_out, batch_size=1, callbacks=save_model_callback)
ending_weights = [tf.keras.backend.get_value(param) for
param in qsim.model.layers[1]._layer_to_wrap.weights]
new_dense_output_quantizer_encoding_max = \
tf.keras.backend.get_value(qsim.model.layers[1].output_quantizers[0]._encoding_max)
for idx, weight in enumerate(running_weights):
assert not np.array_equal(weight, ending_weights[idx])
assert np.array_equal(new_dense_output_quantizer_encoding_max,
running_dense_output_quantizer_encoding_max)
running_weights = ending_weights
running_dense_output_quantizer_encoding_max = new_dense_output_quantizer_encoding_max

h5s = encodings = yamls = saved_models_folders = 0
for file in os.listdir(tmp_dir):
if file.endswith('h5'):
h5s += 1
elif file.endswith('encodings'):
encodings += 1
elif file.endswith('yaml'):
yamls += 1
else:
saved_models_folders += 1

for file_count in [h5s, encodings, yamls, saved_models_folders]:
assert file_count == 1, f"QAT Save Callback did not work"


def test_range_learning():
if version.parse(tf.version.VERSION) >= version.parse("2.00"):
Expand Down

0 comments on commit 0d93665

Please sign in to comment.