Skip to content

Commit

Permalink
Add EMD method to post-training quantization (post_quant)
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill committed Mar 10, 2022
1 parent befa78e commit fe5ef53
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def __init__(self,
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = [
'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
]
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
Expand Down Expand Up @@ -349,7 +349,7 @@ def __init__(self,
# The vars for algo = avg
self._quantized_var_avg = {}
# The best loss of algo = mse
self._best_mse_loss = {}
self._best_calibration_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}

Expand Down Expand Up @@ -408,7 +408,7 @@ def quantize(self):
np.array(self._quantized_var_avg[var_name]).mean()
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
if self._algo in ["KL", "abs_max", "hist", "avg", "mse", "emd"]:
self._update_program()
else:
self._save_input_threhold()
Expand Down Expand Up @@ -582,6 +582,8 @@ def _sampling(self):
self._sample_min_max()
elif self._algo == "mse":
self._sample_mse()
elif self._algo == "emd":
self._sample_emd()
elif self._algo in ["KL", "hist"]:
self._sample_histogram()

Expand Down Expand Up @@ -610,8 +612,8 @@ def _sample_mse(self):
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
s = 0.3
if var_name not in self._best_mse_loss:
self._best_mse_loss[var_name] = float('inf')
if var_name not in self._best_calibration_loss:
self._best_calibration_loss[var_name] = float('inf')
while s <= 1.0:
scale = s * abs_max_value
s += 0.02
Expand All @@ -620,8 +622,49 @@ def _sample_mse(self):
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
if mse_loss <= self._best_mse_loss[var_name]:
self._best_mse_loss[var_name] = mse_loss
if mse_loss <= self._best_calibration_loss[var_name]:
self._best_calibration_loss[var_name] = mse_loss
self._quantized_threshold[var_name] = scale

def _sample_emd(self):
    """Search per-activation quantization scales with an EMD-style loss.

    On the first sampling pass (``self._quantized_threshold`` still empty),
    seeds the threshold table with abs-max values for every quantized
    weight variable — per-tensor or per-channel, depending on
    ``self._weight_quantize_type``.  Then, for each quantized activation
    variable, sweeps candidate scales ``s * abs_max`` for ``s`` from 0.3
    to 1.0 in steps of 0.02 and keeps the scale minimizing
    ``|mean(x) - mean(x_qdq)| + |std(x) - std(x_qdq)``, where ``x_qdq``
    is the quantize-dequantize reconstruction of ``x``.
    """
    if self._quantized_threshold == {}:
        # First pass: record abs-max thresholds for the weights.
        for weight_name in self._quantized_weight_var_name:
            weight = _load_variable_data(self._scope, weight_name)
            if self._weight_quantize_type == "abs_max":
                threshold = float(np.max(np.abs(weight)))
            elif self._weight_quantize_type == "channel_wise_abs_max":
                # The quantization axis depends on the op type.
                if self._weight_op_pairs[
                        weight_name] in _channelwise_quant_axis1_ops:
                    threshold = [
                        float(np.max(np.abs(weight[:, i])))
                        for i in range(weight.shape[1])
                    ]
                else:
                    threshold = [
                        float(np.max(np.abs(weight[i])))
                        for i in range(weight.shape[0])
                    ]
            self._quantized_threshold[weight_name] = threshold
    _logger.info("EMD searching stage ...")
    for act_name in self._quantized_act_var_name:
        act = _load_variable_data(self._scope, act_name).flatten()
        abs_max_value = float(np.max(np.abs(act)))
        if abs_max_value == 0.0:
            abs_max_value = 1e-8  # avoid dividing by a zero scale below
        self._best_calibration_loss.setdefault(act_name, float('inf'))
        # Number of positive quantization levels for signed activations.
        bins = 2**(self._activation_bits - 1) - 1
        s = 0.3
        while s <= 1.0:
            scale = s * abs_max_value
            s += 0.02
            quant_dequant = np.round(
                np.clip(act, 0.0, scale) / scale * bins) / bins * scale
            loss = np.abs(np.mean(act) - np.mean(quant_dequant)) + np.abs(
                np.std(act) - np.std(quant_dequant))
            if loss <= self._best_calibration_loss[act_name]:
                self._best_calibration_loss[act_name] = loss
                self._quantized_threshold[act_name] = scale

def _sample_avg(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,26 @@ def test_post_training_mse(self):
quant_iterations)


class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
    def test_post_training_emd(self):
        """Run post-training quantization of the MNIST model with algo='emd'."""
        # NOTE(review): this method was misnamed test_post_training_mse
        # (copy-pasted from the mse test) even though it exercises the
        # 'emd' algorithm; renamed to match what it actually tests.
        model_name = "mnist_model"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "emd"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
def test_post_training_avg(self):
model_name = "mnist_model"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,5 +394,27 @@ def test_post_training_abs_max_mobilenetv1(self):
diff_threshold)


class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
    def test_post_training_emd_mobilenetv1(self):
        """Run post-training quantization of MobileNet-V1 with algo='emd'."""
        # NOTE(review): this method was misnamed
        # test_post_training_avg_mobilenetv1 (copy-pasted from the avg test)
        # even though it exercises the 'emd' algorithm; renamed accordingly.
        model = "MobileNet-V1"
        algo = "emd"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
        quantizable_op_type = [
            "conv2d",
            "depthwise_conv2d",
            "mul",
        ]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.025
        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold)


# Allow running this test module directly (outside the CI test runner).
if __name__ == '__main__':
    unittest.main()

1 comment on commit fe5ef53

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can now ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.