add EMD method of post_quant #40421

Merged 1 commit on Mar 11, 2022
Changes from all commits
@@ -272,7 +272,7 @@ def __init__(self,
         ]
         self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
         self._support_algo_type = [
-            'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
+            'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
         ]
         self._dynamic_quantize_op_type = ['lstm']
         self._support_quantize_op_type = \
@@ -349,7 +349,7 @@ def __init__(self,
         # The vars for algo = avg
         self._quantized_var_avg = {}
         # The best loss of algo = mse
-        self._best_mse_loss = {}
+        self._best_calibration_loss = {}
         # The threshold for algo = abs_max, mse or avg
         self._quantized_threshold = {}
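
(The rename from `_best_mse_loss` to `_best_calibration_loss` generalizes the per-variable best-loss cache, so the same dictionary serves both the existing mse search and the new emd search added below.)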

@@ -408,7 +408,7 @@ def quantize(self):
                     np.array(self._quantized_var_avg[var_name]).mean()
         if self._algo in ["KL", "hist"]:
             self._calculate_kl_hist_threshold()
-        if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
+        if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]:
             self._update_program()
         else:
             self._save_input_threhold()
@@ -582,6 +582,8 @@ def _sampling(self):
             self._sample_min_max()
         elif self._algo == "mse":
             self._sample_mse()
+        elif self._algo == "emd":
+            self._sample_emd()
         elif self._algo in ["KL", "hist"]:
             self._sample_histogram()

@@ -610,8 +612,8 @@ def _sample_mse(self):
             abs_max_value = float(np.max(np.abs(var_tensor)))
             abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
             s = 0.3
-            if var_name not in self._best_mse_loss:
-                self._best_mse_loss[var_name] = float('inf')
+            if var_name not in self._best_calibration_loss:
+                self._best_calibration_loss[var_name] = float('inf')
             while s <= 1.0:
                 scale = s * abs_max_value
                 s += 0.02
@@ -620,8 +622,49 @@ def _sample_mse(self):
                 quant_dequant_var = np.round(
                     np.clip(var_tensor, 0.0, scale) / scale *
                     bins) / bins * scale
                 mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
-                if mse_loss <= self._best_mse_loss[var_name]:
-                    self._best_mse_loss[var_name] = mse_loss
+                if mse_loss <= self._best_calibration_loss[var_name]:
+                    self._best_calibration_loss[var_name] = mse_loss
                     self._quantized_threshold[var_name] = scale
 
+    def _sample_emd(self):
+        if self._quantized_threshold == {}:
+            for var_name in self._quantized_weight_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                if self._weight_quantize_type == "abs_max":
+                    abs_max_value = float(np.max(np.abs(var_tensor)))
+                elif self._weight_quantize_type == "channel_wise_abs_max":
+                    abs_max_value = []
+                    if self._weight_op_pairs[
+                            var_name] in _channelwise_quant_axis1_ops:
+                        for i in range(var_tensor.shape[1]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[:, i]))))
+                    else:
+                        for i in range(var_tensor.shape[0]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[i]))))
+                self._quantized_threshold[var_name] = abs_max_value
+        _logger.info("EMD searching stage ...")
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            var_tensor = var_tensor.flatten()
+            abs_max_value = float(np.max(np.abs(var_tensor)))
+            abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
+            s = 0.3
+            if var_name not in self._best_calibration_loss:
+                self._best_calibration_loss[var_name] = float('inf')
+            while s <= 1.0:
+                scale = s * abs_max_value
+                s += 0.02
+                bins = 2**(self._activation_bits - 1) - 1
+                quant_dequant_var = np.round(
+                    np.clip(var_tensor, 0.0, scale) / scale *
+                    bins) / bins * scale
+                emd_loss = np.abs(
+                    np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
+                        np.std(var_tensor) - np.std(quant_dequant_var))
+                if emd_loss <= self._best_calibration_loss[var_name]:
+                    self._best_calibration_loss[var_name] = emd_loss
+                    self._quantized_threshold[var_name] = scale
+
     def _sample_avg(self):
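
The heart of `_sample_emd` is the scale sweep at the end. Despite the name, the loss is not a full earth mover's (Wasserstein) distance: it is a cheap two-moment proxy, |Δmean| + |Δstd|, between the original activations and their quantize-dequantize reconstruction. A minimal standalone sketch of that sweep (the function name `search_emd_scale` and the synthetic input are illustrative, not part of the PR):

import numpy as np

def search_emd_scale(tensor, activation_bits=8):
    # Illustrative re-implementation of the activation scale search in
    # _sample_emd. The loss matches the diff above: it compares only the
    # first two moments (mean and std) of the original tensor and its
    # quantize-dequantize round trip.
    tensor = tensor.flatten()
    abs_max_value = float(np.max(np.abs(tensor)))
    abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
    bins = 2 ** (activation_bits - 1) - 1  # 127 for 8-bit signed
    best_loss = float('inf')
    best_scale = abs_max_value
    s = 0.3
    while s <= 1.0:  # candidate scales: 30% to 100% of the abs-max
        scale = s * abs_max_value
        s += 0.02
        # Same round trip as the PR (note the lower clip at 0.0, as in the source).
        quant_dequant = np.round(
            np.clip(tensor, 0.0, scale) / scale * bins) / bins * scale
        loss = (abs(tensor.mean() - quant_dequant.mean()) +
                abs(tensor.std() - quant_dequant.std()))
        if loss <= best_loss:
            best_loss = loss
            best_scale = scale
    return best_scale

# Example: an outlier pulls the plain abs-max scale up; the search can pick
# a smaller clipping threshold that better preserves the distribution.
acts = np.concatenate([np.abs(np.random.randn(10000)), [12.0]])
print(search_emd_scale(acts))

Compared with the mse search a few lines up, only the loss function changes; the candidate-scale grid (0.30 to 1.00 of abs-max, step 0.02) and the quantize-dequantize formula are identical.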
@@ -244,6 +244,26 @@ def test_post_training_mse(self):
                       quant_iterations)
 
 
+class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
+    def test_post_training_emd(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "emd"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
 class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
     def test_post_training_avg(self):
         model_name = "mnist_model"
@@ -394,5 +394,27 @@ def test_post_training_abs_max_mobilenetv1(self):
                       diff_threshold)
 
 
+class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_emd_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "emd"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.025
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
+
+
 if __name__ == '__main__':
     unittest.main()
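
The tests above exercise the new path through the shared run_test helper; from user code, selecting the method is just a matter of passing the new algo string. A hedged sketch, assuming the PostTrainingQuantization constructor arguments of the Paddle release this PR targets (the model directory, calibration reader, and output path are placeholders):

import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

def sample_generator():
    # Placeholder: yield preprocessed calibration samples here.
    ...

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./mnist_model",          # placeholder float inference model
    sample_generator=sample_generator,
    batch_size=10,
    batch_nums=5,
    algo="emd",                         # the algorithm added by this PR
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])
ptq.quantize()
ptq.save_quantized_model("./mnist_emd_quant")  # placeholder output path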