diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index 9da798375af25..97b4116826a2a 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -272,7 +272,7 @@ def __init__(self,
         ]
         self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
         self._support_algo_type = [
-            'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
+            'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
         ]
         self._dynamic_quantize_op_type = ['lstm']
         self._support_quantize_op_type = \
@@ -349,7 +349,7 @@ def __init__(self,
         # The vars for algo = avg
         self._quantized_var_avg = {}
-        # The best loss of algo = mse
-        self._best_mse_loss = {}
-        # The threshold for algo = abs_max, mse or avg
+        # The best loss of algo = mse or emd
+        self._best_calibration_loss = {}
+        # The threshold for algo = abs_max, mse, emd or avg
         self._quantized_threshold = {}
 
@@ -408,7 +408,7 @@ def quantize(self):
                 np.array(self._quantized_var_avg[var_name]).mean()
         if self._algo in ["KL", "hist"]:
             self._calculate_kl_hist_threshold()
-        if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
+        if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]:
             self._update_program()
         else:
             self._save_input_threhold()
@@ -582,6 +582,8 @@ def _sampling(self):
             self._sample_min_max()
         elif self._algo == "mse":
             self._sample_mse()
+        elif self._algo == "emd":
+            self._sample_emd()
         elif self._algo in ["KL", "hist"]:
             self._sample_histogram()
 
@@ -610,8 +612,8 @@ def _sample_mse(self):
             abs_max_value = float(np.max(np.abs(var_tensor)))
             abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
             s = 0.3
-            if var_name not in self._best_mse_loss:
-                self._best_mse_loss[var_name] = float('inf')
+            if var_name not in self._best_calibration_loss:
+                self._best_calibration_loss[var_name] = float('inf')
             while s <= 1.0:
                 scale = s * abs_max_value
                 s += 0.02
@@ -620,8 +622,57 @@ def _sample_mse(self):
                     np.clip(var_tensor, 0.0, scale) / scale *
                     bins) / bins * scale
                 mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
-                if mse_loss <= self._best_mse_loss[var_name]:
-                    self._best_mse_loss[var_name] = mse_loss
+                if mse_loss <= self._best_calibration_loss[var_name]:
+                    self._best_calibration_loss[var_name] = mse_loss
+                    self._quantized_threshold[var_name] = scale
+
+    def _sample_emd(self):
+        """Search activation thresholds for algo = emd.
+
+        While no threshold has been recorded yet, weight thresholds are
+        taken from abs_max (per tensor or per channel). For every
+        activation, candidate scales s * abs_max are tried for s starting
+        at 0.3 in steps of 0.02 while s <= 1.0, and the scale with the
+        lowest |mean difference| + |std difference| loss is kept.
+        """
+        if self._quantized_threshold == {}:
+            for var_name in self._quantized_weight_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                if self._weight_quantize_type == "abs_max":
+                    abs_max_value = float(np.max(np.abs(var_tensor)))
+                elif self._weight_quantize_type == "channel_wise_abs_max":
+                    abs_max_value = []
+                    if self._weight_op_pairs[
+                            var_name] in _channelwise_quant_axis1_ops:
+                        for i in range(var_tensor.shape[1]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[:, i]))))
+                    else:
+                        for i in range(var_tensor.shape[0]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[i]))))
+                self._quantized_threshold[var_name] = abs_max_value
+        _logger.info("EMD searching stage ...")
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            var_tensor = var_tensor.flatten()
+            abs_max_value = float(np.max(np.abs(var_tensor)))
+            abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
+            s = 0.3
+            if var_name not in self._best_calibration_loss:
+                self._best_calibration_loss[var_name] = float('inf')
+            while s <= 1.0:
+                scale = s * abs_max_value
+                s += 0.02
+                bins = 2**(self._activation_bits - 1) - 1
+                quant_dequant_var = np.round(
+                    np.clip(var_tensor, 0.0, scale) / scale *
+                    bins) / bins * scale
+                emd_loss = np.abs(
+                    np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
+                        np.std(var_tensor) - np.std(quant_dequant_var))
+                if emd_loss <= self._best_calibration_loss[var_name]:
+                    self._best_calibration_loss[var_name] = emd_loss
                     self._quantized_threshold[var_name] = scale
 
     def _sample_avg(self):
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
index da5c5d6dc9441..4b70f5b103778 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
@@ -244,6 +244,26 @@ def test_post_training_mse(self):
                       quant_iterations)
 
 
+class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
+    def test_post_training_emd(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "emd"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
 class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
     def test_post_training_avg(self):
         model_name = "mnist_model"
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
index 7161104861006..f83306aca1dc0 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
@@ -394,5 +394,27 @@ def test_post_training_abs_max_mobilenetv1(self):
                       diff_threshold)
 
 
+class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_emd_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "emd"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.025
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
+
+
 if __name__ == '__main__':
     unittest.main()