Adds the set_input_type API for specifying input data types
eee4017 authored and Frank Lin (Engrg-Hardware 1) committed May 19, 2023
1 parent 9092ddc commit ab770f8
Showing 3 changed files with 93 additions and 72 deletions.
142 changes: 74 additions & 68 deletions test/ir/inference/auto_scan_test.py
@@ -662,6 +662,9 @@ def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
         )
         return config
 
+    def get_avalible_input_type(self) -> List[np.dtype]:
+        return [np.float32]
+
     def assert_tensors_near(
         self,
         atol: float,
@@ -759,78 +762,81 @@ def random_to_skip():
                 nodes_num,
                 threshold,
             ) in self.sample_predictor_configs(prog_config):
-                if os.path.exists(self.cache_dir):
-                    shutil.rmtree(self.cache_dir)
-
-                if isinstance(threshold, float):
-                    atol = threshold
-                    rtol = 1e-8
-                elif isinstance(threshold, list) or isinstance(
-                    threshold, tuple
-                ):
-                    atol = threshold[0]
-                    rtol = threshold[1]
-                else:
-                    raise NotImplementedError
-
-                is_fp8 = (
-                    pred_config.tensorrt_precision_mode()
-                    == paddle_infer.PrecisionType.Int8
-                )
-                if (not is_fp8 and quant) or (is_fp8 and not quant):
-                    continue
-
-                ignore_flag = False
-                for teller, reason, note in self.ignore_cases:
-                    if teller(prog_config, pred_config):
-                        ignore_flag = True
-                        if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
-                            self.ignore_log(
-                                f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}"
-                            )
-                        elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
-                            self.ignore_log(
-                                f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}"
-                            )
-                        else:
-                            raise NotImplementedError
-                        break
-
-                if ignore_flag:
-                    continue
-
-                try:
-                    pred_config_deserialize = paddle_infer.Config(pred_config)
-                    trt_result = self.run_test_config(
-                        model, params, prog_config, pred_config, feed_data
-                    )
-                    self.assert_tensors_near(
-                        atol, rtol, trt_result, baseline_result
-                    )
-                    trt_engine_num, paddle_op_num = nodes_num
-                    self.assert_op_size(trt_engine_num, paddle_op_num)
-
-                    # deserialize test
-                    if trt_engine_num > 0:
-                        self.run_test_config(
-                            model,
-                            params,
-                            prog_config,
-                            pred_config_deserialize,
-                            feed_data,
-                        )
-
-                    self.success_log(f"program_config: {prog_config}")
-                    self.success_log(
-                        f"predictor_config: {self.inference_config_str(pred_config)}"
-                    )
-                except Exception as e:
-                    self.fail_log(f"program_config: {prog_config}")
-                    self.fail_log(
-                        f"predictor_config: {self.inference_config_str(pred_config)}"
-                    )
-                    self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
-                    all_passes = False
+                for input_type in self.get_avalible_input_type():
+                    if os.path.exists(self.cache_dir):
+                        shutil.rmtree(self.cache_dir)
+
+                    if isinstance(threshold, float):
+                        atol = threshold
+                        rtol = 1e-8
+                    elif isinstance(threshold, list) or isinstance(
+                        threshold, tuple
+                    ):
+                        atol = threshold[0]
+                        rtol = threshold[1]
+                    else:
+                        raise NotImplementedError
+
+                    is_fp8 = (
+                        pred_config.tensorrt_precision_mode()
+                        == paddle_infer.PrecisionType.Int8
+                    )
+                    if (not is_fp8 and quant) or (is_fp8 and not quant):
+                        continue
+
+                    ignore_flag = False
+                    for teller, reason, note in self.ignore_cases:
+                        if teller(prog_config, pred_config):
+                            ignore_flag = True
+                            if reason == IgnoreReasons.TRT_NOT_IMPLEMENTED:
+                                self.ignore_log(
+                                    f"[TRT_NOT_IMPLEMENTED] {note} vs {self.inference_config_str(pred_config)}"
+                                )
+                            elif reason == IgnoreReasons.TRT_NOT_SUPPORT:
+                                self.ignore_log(
+                                    f"[TRT_NOT_SUPPORT] {note} vs {self.inference_config_str(pred_config)}"
+                                )
+                            else:
+                                raise NotImplementedError
+                            break
+
+                    if ignore_flag:
+                        continue
+
+                    try:
+                        pred_config_deserialize = paddle_infer.Config(
+                            pred_config
+                        )
+                        trt_result = self.run_test_config(
+                            model, params, prog_config, pred_config, feed_data
+                        )
+                        self.assert_tensors_near(
+                            atol, rtol, trt_result, baseline_result
+                        )
+                        trt_engine_num, paddle_op_num = nodes_num
+                        self.assert_op_size(trt_engine_num, paddle_op_num)
+
+                        # deserialize test
+                        if trt_engine_num > 0:
+                            self.run_test_config(
+                                model,
+                                params,
+                                prog_config,
+                                pred_config_deserialize,
+                                feed_data,
+                            )
+
+                        self.success_log(f"program_config: {prog_config}")
+                        self.success_log(
+                            f"predictor_config: {self.inference_config_str(pred_config)}"
+                        )
+                    except Exception as e:
+                        self.fail_log(f"program_config: {prog_config}")
+                        self.fail_log(
+                            f"predictor_config: {self.inference_config_str(pred_config)}"
+                        )
+                        self.fail_log(f"\033[1;31m ERROR INFO: {e}\033[0m")
+                        all_passes = False
 
         self.assertTrue(all_passes)
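
Note on the new hook: the base implementation of get_avalible_input_type returns only float32, and run_test now repeats every sampled (program, predictor) pair once per dtype the hook returns, clearing the TensorRT cache between runs. A minimal sketch of a derived test opting into an extra precision follows; the class name and import path are illustrative, not part of this commit:

    import numpy as np

    from auto_scan_test import TrtLayerAutoScanTest  # import path assumed


    class TrtConvertExampleTest(TrtLayerAutoScanTest):  # hypothetical subclass
        def get_avalible_input_type(self):
            # Ask the harness to run each sampled config twice:
            # once with float32 data and once with float16 data.
            return [np.float32, np.float16]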

19 changes: 17 additions & 2 deletions test/ir/inference/program_config.py
@@ -54,8 +54,8 @@ def __init__(
         if data_gen is not None:
             self.data_gen = data_gen
             self.data = data_gen()
-            self.dtype = data_gen().dtype
-            self.shape = data_gen().shape
+            self.dtype = self.data.dtype
+            self.shape = self.data.shape
         else:
             assert (
                 shape is not None
@@ -67,6 +67,11 @@ def __init__(
     def __repr__(self):
         return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
 
+    def astype(self, type: np.dtype):
+        self.data = self.data.astype(type)
+        self.dtype = self.data.dtype
+        return self
+
 
 class VarType(enum.Enum):
     LOD_TENSOR = 1
@@ -270,6 +275,16 @@ def __repr__(self):
 
         return log_str
 
+    def set_input_type(self, type: np.dtype):
+        for inp in self.inputs.values():
+            inp.astype(type)
+        for weight in self.weights.values():
+            weight.astype(type)
+        return self
+
+    def get_input_type(self) -> np.dtype:
+        return next(iter(self.inputs.values())).dtype
+
 
 def create_fake_model(program_config):
     '''Create a Paddle model(in memory) according to the given config.'''
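
Taken together, TensorConfig.astype casts a single tensor in place and keeps its recorded dtype in sync, while ProgramConfig.set_input_type applies that cast to every input and weight. A runnable numpy-only sketch of the same bookkeeping (tensor names and shapes are made up for illustration):

    import numpy as np

    # Stand-ins for the arrays held by prog_config.inputs / prog_config.weights.
    inputs = {"input_data": np.random.random([1, 3, 8, 8]).astype(np.float32)}
    weights = {"conv_weight": np.random.random([8, 3, 3, 3]).astype(np.float32)}

    # set_input_type(np.float16) amounts to casting every entry in both dicts:
    inputs = {name: arr.astype(np.float16) for name, arr in inputs.items()}
    weights = {name: arr.astype(np.float16) for name, arr in weights.items()}

    # get_input_type() reports the dtype of the first input:
    print(next(iter(inputs.values())).dtype)  # float16
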
4 changes: 2 additions & 2 deletions test/ir/inference/test_trt_convert_bitwise_not.py
@@ -103,11 +103,11 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
            ver = paddle_infer.get_trt_compile_version()
            trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
            if trt_version >= 8400:
-                if self.dims == 1 and not dynamic_shape:
+                if self.dims == 1:
                    return 0, 3
                return 1, 2
            else:
-                if (self.dims == 1 and not dynamic_shape) or (
+                if self.dims <= 2 or (
                    program_config.inputs['input_data'].dtype
                    in ['bool', 'int8', 'uint8']
                ):
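
For reference, trt_version packs the compile-time TensorRT version into one integer before the >= 8400 comparison above. A small worked example, with the version tuple assumed for illustration:

    ver = (8, 4, 1)  # e.g. TensorRT 8.4.1 from paddle_infer.get_trt_compile_version()
    trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10  # 8000 + 400 + 10
    print(trt_version)  # 8410 -> takes the trt_version >= 8400 branch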
