scale, square, sum, swish trt op converter support zero dim (#53660)
yuanlehome authored May 10, 2023
1 parent 65e57a7 commit 6a279df
Showing 5 changed files with 133 additions and 72 deletions.
62 changes: 41 additions & 21 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -105,7 +105,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"erf", "floor", "round",
"sign", "silu", "logical_not",
"reciprocal", "tanh_shrink", "logsigmoid",
"rsqrt"};
"rsqrt", "swish"};
std::unordered_set<std::string> unary_list = {
"exp", "log", "sqrt", "abs", "sin",
"cos", "tan", "tanh", "sinh", "cosh",
@@ -1194,9 +1194,9 @@ struct SimpleOpTypeSetTeller : public Teller {
dtype == framework::proto::VarType::FP16)) {
return false;
}
if (x_shape.size() == 1) {
VLOG(3)
<< "Scale op does not support 1-dimensional input in tensorrt";
if (x_shape.size() == 1 || x_shape.size() == 0) {
VLOG(3) << "Scale op does not support 0 or 1-dimensional input in "
"tensorrt";
return false;
}
} else {
@@ -1548,8 +1548,24 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
// remember that 1D input in static shape mode is filtered at the beginning

if (op_type == "sum") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var = block->FindVar(x_var_name);
const auto x_shape = x_var->GetShape();
if (!with_dynamic_shape && (x_shape.size() == 0 || x_shape.size() == 1)) {
VLOG(3) << op_type
<< " op does not support 0-D or 1-D input in tensorrt "
"with static shape.";
return false;
}
return true;
}

@@ -1803,22 +1819,7 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
}
if (op_type == "swish") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "swish op does not support input's dim is 1 in tensorrt.";
return false;
}
}

if (op_type == "prelu") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "Invalid input X's size of prelu TRT converter. "
Expand Down Expand Up @@ -2180,6 +2181,25 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "square") {
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var = block->FindVar(x_var_name);
const auto x_shape = x_var->GetShape();
if (!with_dynamic_shape && x_shape.size() == 0) {
VLOG(3) << op_type
<< " op does not support 0-D input in tensorrt "
"with static shape.";
return false;
}
}

if (op_type == "clip") {
// Paddle-TRT does not support the input tensors: Min and Max
auto clip_inputs = desc.Inputs();
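The teller hunks above share one rule: under static shape, TensorRT has no profile for a 0-D (and, for several of these ops, 1-D) input, so the op must fall back to Paddle; under dynamic shape the converters now accept it. Below is a minimal Python sketch of that rule as I read the diff and the test expectations; the helper name and the per-op sets are illustrative, not part of the Paddle codebase.

```python
# Sketch of the shape rule added in op_teller.cc (illustrative only).
def is_op_supported(op_type: str, x_shape: list, with_dynamic_shape: bool) -> bool:
    # Per the diff and the test expectations: scale, sum, and swish reject
    # rank-0 and rank-1 inputs under static shape; square rejects rank 0.
    rejects_rank_0_or_1 = {"scale", "sum", "swish"}
    rejects_rank_0 = {"square"}
    rank = len(x_shape)
    if not with_dynamic_shape:
        if op_type in rejects_rank_0_or_1 and rank in (0, 1):
            return False
        if op_type in rejects_rank_0 and rank == 0:
            return False
    return True

assert is_op_supported("sum", [3, 64], with_dynamic_shape=False)
assert not is_op_supported("sum", [], with_dynamic_shape=False)  # 0-D, static
assert is_op_supported("sum", [], with_dynamic_shape=True)       # 0-D, dynamic
```

Note that the swish-specific block is deleted outright: adding "swish" to the generic activation set at the top of the teller lets the shared checks, including the early 1-D static-shape filter the removed comment referred to, cover it.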
23 changes: 10 additions & 13 deletions test/ir/inference/test_trt_convert_scale.py
@@ -43,12 +43,14 @@ def generate_input1(attrs: List[Dict[str, Any]], batch, is_int):
)
elif self.dims == 1:
return np.ones([24]).astype(np.int32 if is_int else np.float32)
elif self.dims == 0:
return np.ones([]).astype(np.int32 if is_int else np.float32)

def generate_weight1(attrs: List[Dict[str, Any]], is_int):
return np.ones([1]).astype(np.int32 if is_int else np.float32)

for num_input in [0, 1]:
for dims in [1, 2, 3, 4]:
for dims in [0, 1, 2, 3, 4]:
for batch in [1, 2]:
for scale in [0.1, -1.0]:
for bias in [0.0, 1.2]:
@@ -141,13 +143,19 @@ def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"scale_input": [24]}
self.dynamic_shape.max_input_shape = {"scale_input": [48]}
self.dynamic_shape.opt_input_shape = {"scale_input": [24]}
elif self.dims == 0:
self.dynamic_shape.min_input_shape = {"scale_input": []}
self.dynamic_shape.max_input_shape = {"scale_input": []}
self.dynamic_shape.opt_input_shape = {"scale_input": []}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 3
return 1, 2

attrs = [
@@ -189,23 +197,12 @@ def teller1(program_config, predictor_config):
)

def teller2(program_config, predictor_config):
if self.dims == 1 and len(self.dynamic_shape.min_input_shape) == 0:
return True
return False

self.add_skip_case(
teller2,
SkipReasons.TRT_NOT_SUPPORT,
"INPUT DIM EQUAL TO 1 OF STATIC SHAPE NOT SUPPORT",
)

def teller3(program_config, predictor_config):
if self.is_int and len(self.dynamic_shape.min_input_shape) == 0:
return True
return False

self.add_skip_case(
teller3,
teller2,
SkipReasons.TRT_NOT_SUPPORT,
"INTEGER INPUT OF STATIC SHAPE NOT SUPPORT",
)
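For context on what the new dims == 0 cases feed through the harness: a 0-D numpy array is a true scalar tensor with shape (), not a 1-element vector. A quick runnable illustration of the inputs generate_input1 now produces:

```python
import numpy as np

scalar = np.ones([])    # 0-D: shape (), ndim 0 -- the newly added case
vector = np.ones([24])  # 1-D: shape (24,), ndim 1 -- the previous lowest rank
print(scalar.shape, scalar.ndim)  # -> () 0
print(vector.shape, vector.ndim)  # -> (24,) 1

# A scale-like computation on a 0-D tensor keeps the 0-D shape, which is
# what the converter has to preserve end to end.
out = 0.1 * scalar + 1.2
print(out.shape)  # -> ()
```

This is also why the 0-D dynamic-shape profile sets min/max/opt to the empty shape []: a scalar has no dimension left to range over.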
64 changes: 34 additions & 30 deletions test/ir/inference/test_trt_convert_square.py
@@ -29,7 +29,9 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:

def sample_program_configs(self):
def generate_input1(dims):
if dims == 1:
if dims == 0:
return np.ones([]).astype(np.float32)
elif dims == 1:
return np.ones([3]).astype(np.float32)
elif dims == 2:
return np.ones([3, 64]).astype(np.float32)
@@ -38,40 +40,42 @@ def generate_input1(dims):
else:
return np.ones([1, 3, 64, 64]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for alpha in [1.0, 2.0, 3.0]:
self.dims = dims

ops_config = [
{
"op_type": "square",
"op_inputs": {
"X": ["input_data"],
},
"op_outputs": {"Out": ["output_data"]},
"op_attrs": {},
}
]
ops = self.generate_op_config(ops_config)

program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1, dims)
)
for dims in [0, 1, 2, 3, 4]:
self.dims = dims
ops_config = [
{
"op_type": "square",
"op_inputs": {
"X": ["input_data"],
},
outputs=["output_data"],
)

yield program_config
"op_outputs": {"Out": ["output_data"]},
"op_attrs": {},
}
]
ops = self.generate_op_config(ops_config)

program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1, dims)
)
},
outputs=["output_data"],
)

yield program_config

def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 1:
if self.dims == 0:
self.dynamic_shape.min_input_shape = {"input_data": []}
self.dynamic_shape.max_input_shape = {"input_data": []}
self.dynamic_shape.opt_input_shape = {"input_data": []}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"input_data": [1]}
self.dynamic_shape.max_input_shape = {"input_data": [128]}
self.dynamic_shape.opt_input_shape = {"input_data": [64]}
@@ -102,7 +106,7 @@ def clear_dynamic_shape():
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and self.dims == 1:
if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 3
return 1, 2

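The updated generate_trt_nodes_num encodes the fallback bookkeeping. On my reading of the harness convention, which this diff does not spell out, the returned pair is (TRT engine ops, remaining Paddle ops). A sketch of the arithmetic behind the (1, 2) / (0, 3) pairs, under that assumption:

```python
# Sketch of the expected-node arithmetic; the tuple meaning
# (trt_engine_ops, paddle_ops) is my reading of the harness convention.
def expected_nodes(num_inputs: int, converted: bool) -> tuple:
    # Each test graph is: one feed per input + the op itself + one fetch.
    if converted:
        # The op fuses into one TRT engine; feeds and fetch stay in Paddle.
        return 1, num_inputs + 1
    # Fallback: no engine node, the op runs as a plain Paddle op.
    return 0, num_inputs + 2

assert expected_nodes(1, converted=True) == (1, 2)   # square on TRT
assert expected_nodes(1, converted=False) == (0, 3)  # 0-D square, static shape
assert expected_nodes(3, converted=False) == (0, 5)  # 0-D sum below, static shape
```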
42 changes: 38 additions & 4 deletions test/ir/inference/test_trt_convert_sum.py
@@ -37,6 +37,8 @@ def generate_input1(batch):
return np.ones([batch, 24]).astype(np.float32)
elif self.dims == 1:
return np.ones([24]).astype(np.float32)
elif self.dims == 0:
return np.ones([]).astype(np.float32)

def generate_input2(batch):
if self.dims == 4:
@@ -47,6 +49,8 @@ def generate_input2(batch):
return np.ones([batch, 24]).astype(np.float32)
elif self.dims == 1:
return np.ones([24]).astype(np.float32)
elif self.dims == 0:
return np.ones([]).astype(np.float32)

def generate_input3(batch):
if self.dims == 4:
@@ -57,8 +61,10 @@ def generate_input3(batch):
return np.ones([batch, 24]).astype(np.float32)
elif self.dims == 1:
return np.ones([24]).astype(np.float32)
elif self.dims == 0:
return np.ones([]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for dims in [0, 1, 2, 3, 4]:
for batch in [1, 4]:
self.dims = dims
ops_config = [
@@ -157,14 +163,30 @@ def generate_dynamic_shape():
"input2": [24],
"input3": [24],
}
elif self.dims == 0:
self.dynamic_shape.min_input_shape = {
"input1": [],
"input2": [],
"input3": [],
}
self.dynamic_shape.max_input_shape = {
"input1": [],
"input2": [],
"input3": [],
}
self.dynamic_shape.opt_input_shape = {
"input1": [],
"input2": [],
"input3": [],
}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(dynamic_shape):
if self.dims == 1 and not dynamic_shape:
if (self.dims == 1 or self.dims == 0) and not dynamic_shape:
return 0, 5
return 1, 4

@@ -205,8 +227,10 @@ def generate_input1(batch):
return np.ones([batch, 24]).astype(np.float32)
elif self.dims == 1:
return np.ones([24]).astype(np.float32)
else:
return np.ones([]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for dims in [0, 1, 2, 3, 4]:
for batch in [1, 4]:
self.dims = dims
ops_config = [
@@ -263,14 +287,24 @@ def generate_dynamic_shape():
self.dynamic_shape.opt_input_shape = {
"input1": [24],
}
elif self.dims == 0:
self.dynamic_shape.min_input_shape = {
"input1": [],
}
self.dynamic_shape.max_input_shape = {
"input1": [],
}
self.dynamic_shape.opt_input_shape = {
"input1": [],
}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(dynamic_shape):
if self.dims == 1 and not dynamic_shape:
if (self.dims == 1 or self.dims == 0) and not dynamic_shape:
return 0, 3
return 1, 2

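The sum cases extend the same pattern to a multi-input op: summing three 0-D operands is still a 0-D result, so all three profile entries and the output use the empty shape. A tiny numpy check of that invariant:

```python
import numpy as np

# Three 0-D operands, as in the new sum test case.
x1, x2, x3 = np.ones([]), np.ones([]), np.ones([])
out = x1 + x2 + x3
print(out, out.shape)  # -> 3.0 ()
```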
14 changes: 10 additions & 4 deletions test/ir/inference/test_trt_convert_swish.py
@@ -29,7 +29,9 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:

def sample_program_configs(self):
def generate_input1(dims, attrs: List[Dict[str, Any]]):
if dims == 1:
if dims == 0:
return np.ones([]).astype(np.float32)
elif dims == 1:
return np.ones([3]).astype(np.float32)
elif dims == 2:
return np.ones([3, 64]).astype(np.float32)
@@ -38,7 +40,7 @@ def generate_input1(dims, attrs: List[Dict[str, Any]]):
else:
return np.ones([1, 3, 64, 64]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for dims in [0, 1, 2, 3, 4]:
for beta in [1.0, 2.0, 3.0]:
self.dims = dims

@@ -73,7 +75,11 @@ def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
if self.dims == 1:
if self.dims == 0:
self.dynamic_shape.min_input_shape = {"input_data": []}
self.dynamic_shape.max_input_shape = {"input_data": []}
self.dynamic_shape.opt_input_shape = {"input_data": []}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"input_data": [1]}
self.dynamic_shape.max_input_shape = {"input_data": [128]}
self.dynamic_shape.opt_input_shape = {"input_data": [64]}
@@ -104,7 +110,7 @@ def clear_dynamic_shape():
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
if self.dims == 1:
if (self.dims == 1 or self.dims == 0) and not dynamic_shape:
return 0, 3
return 1, 2

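For reference, swish is x * sigmoid(beta * x), and the test sweeps beta over 1.0 to 3.0. The point of the 0-D case is that the activation is purely elementwise, so the scalar shape passes straight through. A small numpy sketch:

```python
import numpy as np

def swish(x, beta=1.0):
    # swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
    return x / (1.0 + np.exp(-beta * x))

x = np.ones([])  # 0-D input, the newly covered case
y = swish(x, beta=2.0)
print(y, y.shape)  # -> 0.8807970779778823 ()
```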
