From ab770f85b00fd4ac3d1129f04103f65b039d4b73 Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Thu, 27 Apr 2023 02:11:20 +0000 Subject: [PATCH] Adds the set_input_type API for specifying input data types --- test/ir/inference/auto_scan_test.py | 142 +++++++++--------- test/ir/inference/program_config.py | 19 ++- .../inference/test_trt_convert_bitwise_not.py | 4 +- 3 files changed, 93 insertions(+), 72 deletions(-) diff --git a/test/ir/inference/auto_scan_test.py b/test/ir/inference/auto_scan_test.py index 7c763436750fb..f0f094a214694 100755 --- a/test/ir/inference/auto_scan_test.py +++ b/test/ir/inference/auto_scan_test.py @@ -662,6 +662,9 @@ def create_inference_config(self, use_trt=True) -> paddle_infer.Config: ) return config + def get_avalible_input_type(self) -> List[np.dtype]: + return [np.float32] + def assert_tensors_near( self, atol: float, @@ -759,78 +762,81 @@ def random_to_skip(): nodes_num, threshold, ) in self.sample_predictor_configs(prog_config): - if os.path.exists(self.cache_dir): - shutil.rmtree(self.cache_dir) - - if isinstance(threshold, float): - atol = threshold - rtol = 1e-8 - elif isinstance(threshold, list) or isinstance( - threshold, tuple - ): - atol = threshold[0] - rtol = threshold[1] - else: - raise NotImplementedError - - is_fp8 = ( - pred_config.tensorrt_precision_mode() - == paddle_infer.PrecisionType.Int8 - ) - if (not is_fp8 and quant) or (is_fp8 and not quant): - continue - - ignore_flag = False - for teller, reason, note in self.ignore_cases: - if teller(prog_config, pred_config): - ignore_flag = True - if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED: - self.ignore_log( - f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}" - ) - elif reason == IgnoreReasons.TRT_NOT_SUPPORT: - self.ignore_log( - f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}" - ) - else: - raise NotImplementedError - break - - if ignore_flag: - continue + for input_type in self.get_avalible_input_type(): + if os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) - try: - pred_config_deserialize = paddle_infer.Config(pred_config) - trt_result = self.run_test_config( - model, params, prog_config, pred_config, feed_data - ) - self.assert_tensors_near( - atol, rtol, trt_result, baseline_result + if isinstance(threshold, float): + atol = threshold + rtol = 1e-8 + elif isinstance(threshold, list) or isinstance( + threshold, tuple + ): + atol = threshold[0] + rtol = threshold[1] + else: + raise NotImplementedError + + is_fp8 = ( + pred_config.tensorrt_precision_mode() + == paddle_infer.PrecisionType.Int8 ) - trt_engine_num, paddle_op_num = nodes_num - self.assert_op_size(trt_engine_num, paddle_op_num) - - # deserialize test - if trt_engine_num > 0: - self.run_test_config( - model, - params, - prog_config, - pred_config_deserialize, - feed_data, + if (not is_fp8 and quant) or (is_fp8 and not quant): + continue + + ignore_flag = False + for teller, reason, note in self.ignore_cases: + if teller(prog_config, pred_config): + ignore_flag = True + if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED: + self.ignore_log( + f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}" + ) + elif reason == IgnoreReasons.TRT_NOT_SUPPORT: + self.ignore_log( + f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}" + ) + else: + raise NotImplementedError + break + + if ignore_flag: + continue + + try: + pred_config_deserialize = paddle_infer.Config( + pred_config ) + trt_result = self.run_test_config( + model, params, prog_config, pred_config, feed_data + ) + self.assert_tensors_near( + atol, rtol, trt_result, baseline_result + ) + trt_engine_num, paddle_op_num = nodes_num + self.assert_op_size(trt_engine_num, paddle_op_num) + + # deserialize test + if trt_engine_num > 0: + self.run_test_config( + model, + params, + prog_config, + pred_config_deserialize, + feed_data, + ) - self.success_log(f"program_config: {prog_config}") - self.success_log( - f"predictor_config: {self.inference_config_str(pred_config)}" - ) - except Exception as e: - self.fail_log(f"program_config: {prog_config}") - self.fail_log( - f"predictor_config: {self.inference_config_str(pred_config)}" - ) - self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m") - all_passes = False + self.success_log(f"program_config: {prog_config}") + self.success_log( + f"predictor_config: {self.inference_config_str(pred_config)}" + ) + except Exception as e: + self.fail_log(f"program_config: {prog_config}") + self.fail_log( + f"predictor_config: {self.inference_config_str(pred_config)}" + ) + self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m") + all_passes = False self.assertTrue(all_passes) diff --git a/test/ir/inference/program_config.py b/test/ir/inference/program_config.py index 91670cb62c92e..3dfd9b3f95877 100644 --- a/test/ir/inference/program_config.py +++ b/test/ir/inference/program_config.py @@ -54,8 +54,8 @@ def __init__( if data_gen is not None: self.data_gen = data_gen self.data = data_gen() - self.dtype = data_gen().dtype - self.shape = data_gen().shape + self.dtype = self.data.dtype + self.shape = self.data.shape else: assert ( shape is not None @@ -67,6 +67,11 @@ def __init__( def __repr__(self): return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype}) + def astype(self, type: np.dtype): + self.data = self.data.astype(type) + self.dtype = self.data.dtype + return self + class VarType(enum.Enum): LOD_TENSOR = 1 @@ -270,6 +275,16 @@ def __repr__(self): return log_str + def set_input_type(self, type: np.dtype): + for inp in self.inputs.values(): + inp.astype(type) + for weight in self.weights.values(): + weight.astype(type) + return self + + def get_input_type(self) -> np.dtype: + return next(iter(self.inputs.values())).dtype + def create_fake_model(program_config): '''Create a Paddle model(in memory) according to the given config.''' diff --git a/test/ir/inference/test_trt_convert_bitwise_not.py b/test/ir/inference/test_trt_convert_bitwise_not.py index d779e1fce5567..49f00b52237ed 100644 --- a/test/ir/inference/test_trt_convert_bitwise_not.py +++ b/test/ir/inference/test_trt_convert_bitwise_not.py @@ -103,11 +103,11 @@ def generate_trt_nodes_num(attrs, dynamic_shape): ver = paddle_infer.get_trt_compile_version() trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 if trt_version >= 8400: - if self.dims == 1 and not dynamic_shape: + if self.dims == 1: return 0, 3 return 1, 2 else: - if (self.dims == 1 and not dynamic_shape) or ( + if self.dims <= 2 or ( program_config.inputs['input_data'].dtype in ['bool', 'int8', 'uint8'] ):