Skip to content

Commit

Permalink
reshape_opteller (#41090)
Browse files Browse the repository at this point in the history
fix_reshape: for paddle-trt
  • Loading branch information
xiaoxiaohehe001 authored Apr 1, 2022
1 parent f3270fc commit 15d5f6b
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1479,8 +1479,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
std::vector<int> shape =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("shape"));
if (shape.size() >= nvinfer1::Dims::MAX_DIMS) return false;
if (!with_dynamic_shape && (shape[0] == -1 || shape.size() == 1))
if (!with_dynamic_shape) {
if (shape.size() == 1) {
return false;
}
if (shape[0] == 0) {
return true;
} else {
auto* block = desc.Block();
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1,
std::multiplies<int>());
int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1,
std::multiplies<int>());
if (input_num == shape_num) {
return true;
}
}
return false;
}
}

if (op_type == "clip") {
Expand Down

0 comments on commit 15d5f6b

Please sign in to comment.