Skip to content

Commit

Permalink
[Paddle-Inference] Support trt 0dims of expand_as_v2 and mish. (Paddl…
Browse files Browse the repository at this point in the history
…ePaddle#53627)

* support_expand_mish
  • Loading branch information
xiaoxiaohehe001 authored and zhangjun committed May 12, 2023
1 parent 49d31d6 commit 3cb5e8a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
15 changes: 13 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1873,8 +1873,10 @@ struct SimpleOpTypeSetTeller : public Teller {
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "mish op does not support input's dim is 1 in tensorrt.";
if ((!with_dynamic_shape && x_shape.size() == 1) || x_shape.size() == 0) {
VLOG(3) << op_type
<< "mish op does not support input's dim is 1 in tensorrt "
"static shape mode or 0.";
return false;
}
}
Expand Down Expand Up @@ -2612,6 +2614,15 @@ struct SimpleOpTypeSetTeller : public Teller {
"the pass.";
return false;
}

#if IS_TRT_VERSION_LT(8000)
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 0) {
return false; // not supported 0 dim.
}
#endif
}

if (op_type == "grid_sampler") {
Expand Down
13 changes: 11 additions & 2 deletions test/ir/inference/test_trt_convert_expand_as_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def generate_input1(attrs: List[Dict[str, Any]]):
elif self.dims == 1:
self.input_shape = [32]
return np.random.random([32]).astype(np.float32)
elif self.dims == 0:
self.input_shape = []
return np.random.random([]).astype(np.float32)

for dims in [1, 2, 3, 4]:
for dims in [0, 1, 2, 3, 4]:
for shape in [
[10, 8, 32, 32],
[2, 8, 32, 32],
Expand Down Expand Up @@ -125,14 +128,20 @@ def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"expand_v2_input": [32]}
self.dynamic_shape.max_input_shape = {"expand_v2_input": [64]}
self.dynamic_shape.opt_input_shape = {"expand_v2_input": [32]}
elif self.dims == 0:
self.dynamic_shape.min_input_shape = {"expand_v2_input": []}
self.dynamic_shape.max_input_shape = {"expand_v2_input": []}
self.dynamic_shape.opt_input_shape = {"expand_v2_input": []}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape:
ver = paddle_infer.get_trt_compile_version()
ver_num = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if dynamic_shape and (ver_num > 8000 or self.dims > 0):
return 1, 2
else:
return 0, 3
Expand Down

0 comments on commit 3cb5e8a

Please sign in to comment.