add trt support for slice op #41467

Merged · 7 commits · Apr 12, 2022

Changes from all commits
8 changes: 6 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/slice_op.cc
@@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter {
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
std::vector<int> ends =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
std::vector<int> decrease_axises =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("decrease_axis"));

auto input_dims = input->getDimensions();
if (!engine_->with_dynamic_shape()) {
@@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter {
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
int decrease_axis =
decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin);
}
} else {
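Not part of the diff: a minimal numpy sketch of what Paddle's slice `decrease_axis` attribute means, to motivate the converter change above. The shapes are illustrative only.

import numpy as np

x = np.ones([6, 6, 64, 64], dtype=np.float32)

# slice on axes=[1] with starts=[2], ends=[3] keeps a size-1 dimension ...
y = x[:, 2:3, :, :]            # shape: (6, 1, 64, 64)

# ... and decrease_axis=[1] then drops that dimension, like np.squeeze:
z = np.squeeze(y, axis=1)      # shape: (6, 64, 64)

assert y.shape == (6, 1, 64, 64)
assert z.shape == (6, 64, 64)

The converter forwards only the first entry (`decrease_axises[0]`) to the plugin and uses `-1` as the sentinel for "no axis to drop", which is why the teller below restricts dynamic shape to at most one decrease axis.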
29 changes: 17 additions & 12 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -930,10 +930,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (desc.HasAttr("decrease_axis")) {
std::vector<int> decrease_axis =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
if (with_dynamic_shape) {
if (decrease_axis.size() > 1) {
return false;
}
} else {
if (decrease_axis.size() > 0) {
VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0 "
"is not supported in TensorRT";
return false;
}
}
}
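Condensed, the new teller rule is the following (a hedged Python sketch, not Paddle API — the real check is the C++ above):

def slice_supported_by_trt(decrease_axis, with_dynamic_shape):
    # dynamic shape: the plugin can drop at most one axis
    if with_dynamic_shape:
        return len(decrease_axis) <= 1
    # static shape: any decrease_axis still falls back to Paddle
    return len(decrease_axis) == 0

assert slice_supported_by_trt([1], True)
assert not slice_supported_by_trt([1, 2], True)
assert not slice_supported_by_trt([1], False)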

@@ -1046,17 +1052,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
VLOG(3) << "The id and emb size of fused EmbEltwiseLayerNormOp "
"should be same ";
return false;
}
}

if (op_type == "fused_preln_embedding_eltwise_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on "
"dynamic shape mode.";
return false;
}
if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
@@ -1446,7 +1450,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
const auto y_shape = y_var_desc->GetShape();
if (y_shape.size() != 2) {
VLOG(3) << " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = "
<< y_shape.size();
return false;
}
@@ -1590,8 +1595,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
#else
if (dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be float32 using TensorRT "
"< 7.0";
return false;
}
#endif
25 changes: 22 additions & 3 deletions paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu
@@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
#if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
std::vector<int> ends,
std::vector<int> axes, int decrease_axis,
bool with_fp16)
: starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
with_fp16_ = with_fp16;
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
@@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
DeserializeValue(&serialData, &serialLength, &starts_);
DeserializeValue(&serialData, &serialLength, &ends_);
DeserializeValue(&serialData, &serialLength, &axes_);
DeserializeValue(&serialData, &serialLength, &decrease_axis_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
@@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_) + SerializedSize(decrease_axis_) +
SerializedSize(with_fp16_);

return size;
}
@@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, starts_);
SerializeValue(&buffer, ends_);
SerializeValue(&buffer, axes_);
SerializeValue(&buffer, decrease_axis_);
SerializeValue(&buffer, with_fp16_);
}
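The three serialization edits above stand or fall together: getSerializationSize must budget for the new field, and serialize/deserialize must agree on field order. A rough Python analogy, with pickle standing in for TensorRT's SerializeValue/DeserializeValue helpers:

import pickle

def serialize(starts, ends, axes, decrease_axis, with_fp16):
    # the field order here must match the reads in deserialize below
    return pickle.dumps((starts, ends, axes, decrease_axis, with_fp16))

def deserialize(blob):
    starts, ends, axes, decrease_axis, with_fp16 = pickle.loads(blob)
    return starts, ends, axes, decrease_axis, with_fp16

blob = serialize([0], [2], [1], 1, False)
assert deserialize(blob) == ([0], [2], [1], 1, False)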

@@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
ret.d[axes_[i]] = expr_builder.constant(end - start);
#endif
}
if (decrease_axis_ != -1) {
nvinfer1::DimsExprs res;
res.nbDims = ret.nbDims - 1;
int j = 0;
for (size_t i = 0; i < in_dims.nbDims; i++) {
if (decrease_axis_ == i) continue;
res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
*expr_builder.constant(0), *ret.d[i]);
}
return res;
}
return ret;
}
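In plain Python, the output-shape computation above amounts to the following (a sketch assuming in_dims and ret have the same rank, with static starts/ends for simplicity; max(0, d) mirrors the kMAX guard):

def slice_output_dims(in_dims, axes, starts, ends, decrease_axis=-1):
    ret = list(in_dims)
    for axis, start, end in zip(axes, starts, ends):
        ret[axis] = end - start
    if decrease_axis == -1:
        return ret
    # drop the decreased axis, clamping the kept extents at zero
    return [max(0, d) for i, d in enumerate(ret) if i != decrease_axis]

# [6, 6, 64, 64] sliced to width 1 on axis 1, then axis 1 removed:
assert slice_output_dims([6, 6, 64, 64], [1], [2], [3], 1) == [6, 64, 64]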

@@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
auto out_dims = output_desc[0].dims;
if (decrease_axis_ != -1) {
out_dims = input_dims;
out_dims.d[decrease_axis_] = 1;
}
auto num_dims = input_dims.nbDims;
size_t out_num = ProductDim(out_dims);

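The enqueue change reinstates the dropped size-1 axis before counting elements: the output binding carries the decreased shape, but the slice kernel still computes over the pre-squeeze layout, and the element counts are identical either way. A small sanity check of that bookkeeping:

from functools import reduce
from operator import mul

def product_dim(dims):
    return reduce(mul, dims, 1)

input_dims = [6, 6, 64, 64]
out_dims = [6, 64, 64]        # binding shape after decrease_axis = 1

work_dims = list(input_dims)
work_dims[1] = 1              # what the kernel actually slices to
assert product_dim(work_dims) == product_dim(out_dims) == 6 * 64 * 64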
7 changes: 5 additions & 2 deletions paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h
@@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
class SlicePluginDynamic : public DynamicPluginTensorRT {
public:
explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, int decrease_axis,
bool with_fp16);

nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new SlicePluginDynamic(starts_, ends_, axes_, decrease_axis_,
with_fp16_);
}

SlicePluginDynamic(void const* serialData, size_t serialLength);
@@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
std::vector<int> starts_;
std::vector<int> ends_;
std::vector<int> axes_;
int decrease_axis_;
int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_;
cudaStream_t copy_stream_;
@@ -55,11 +55,11 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:

def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([6, 6, 64, 64]).astype(np.float32)

for axes in [[0, 1], [1, 3], [2, 3]]:
for starts in [[0, 1]]:
for ends in [[2, 2], [5, 5]]:
for decrease_axis in [[], [1], [2], [-1], [-100]]:
for infer_flags in [[-1]]:
dics = [{
@@ -97,8 +97,8 @@ def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]}

def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
@@ -107,7 +107,11 @@ def clear_dynamic_shape():

def generate_trt_nodes_num(attrs, dynamic_shape):
inputs = program_config.inputs
if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0:
return 1, 2
if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1:
return 0, 3
if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0:
return 0, 3
if dynamic_shape:
for i in range(len(attrs[0]["starts"])):
@@ -123,7 +127,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]

self.trt_param.max_batch_size = 9
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
@@ -146,7 +150,7 @@ def test(self):
# TODO(inference): fix.
# trt6 and trt7.1 have bugs.
# trt7.2 deserialize has a bug.
self.run_test()
pass
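Reading generate_trt_nodes_num above, the returned pair appears to be (TRT engine count, remaining op count), as in the other convert tests. A hedged summary of the expectations, with the further start/end range checks in the test omitted:

def expected_nodes(dynamic_shape, num_decrease_axes):
    if dynamic_shape and num_decrease_axes == 0:
        return 1, 2   # slice converted: one TRT engine
    if dynamic_shape and num_decrease_axes != 1:
        return 0, 3   # unsupported: falls back to Paddle ops
    if not dynamic_shape and num_decrease_axes != 0:
        return 0, 3   # static shape never supports decrease_axis
    return 1, 2

assert expected_nodes(True, 0) == (1, 2)
assert expected_nodes(True, 2) == (0, 3)
assert expected_nodes(False, 1) == (0, 3)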

