[Paddle-TRT] Enforce use new executor for trt engine memory sharing (#59495)

* enforce use new executor for trt engine memory sharing

* update

* add ut

* fix bug
yuanlehome authored Nov 30, 2023
1 parent 5f9f2ec commit 66a9199
Showing 3 changed files with 35 additions and 10 deletions.
paddle/fluid/inference/api/analysis_predictor.cc (7 additions, 0 deletions)
@@ -370,6 +370,13 @@ AnalysisPredictor::AnalysisPredictor(const AnalysisConfig &config)
"is enabled in Paddle-TRT, we set the id of these predictors to "
"negative sharing_identifier you specified : "
<< predictor_id_;
PADDLE_ENFORCE_EQ(
config_.new_executor_enabled(),
true,
platform::errors::InvalidArgument(
"Please call the config.enable_new_executor() in python or "
"config.EnableNewExecutor() in c++ when you want share the engine "
"context memory of multiple predictors."));
} else {
predictor_id_ = inference::GetUniqueId();
}
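Usage note (not part of the diff): with this check in place, a configuration that turns on TensorRT engine context memory sharing must also enable the new executor, or the AnalysisPredictor constructor fails with the InvalidArgument error above. A minimal Python sketch of a conforming setup, assuming the usual paddle.inference API; the model directories, GPU memory size, TensorRT settings, and the make_config helper are placeholders, not part of this commit:

from paddle.inference import Config, PrecisionType, create_predictor

def make_config(model_dir, sharing_identifier):
    config = Config(model_dir)
    config.enable_use_gpu(256, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    # Enable TRT engine context memory sharing; predictors built with the
    # same sharing_identifier reuse one execution-context memory pool.
    config.enable_tensorrt_memory_optim(True, sharing_identifier)
    # Required by this commit whenever the sharing above is enabled.
    config.enable_new_executor()
    return config

# Both predictors use sharing_identifier 1, so they share engine context
# memory; omitting enable_new_executor() would now raise InvalidArgument.
predictor_a = create_predictor(make_config("./model_a", 1))
predictor_b = create_predictor(make_config("./model_b", 1))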
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h (24 additions, 8 deletions)
@@ -835,14 +835,30 @@ class TensorRTEngineOp : public framework::OperatorBase {
params.calibrator = calibrator_.get();
params.device_id = dev_place.device;
params.with_dynamic_shape = with_dynamic_shape_;
params.context_memory_sharing = Attr<bool>("context_memory_sharing");
params.use_dla = Attr<bool>("use_dla");
params.dla_core = Attr<int>("dla_core");
params.disable_trt_plugin_fp16 = Attr<bool>("disable_trt_plugin_fp16");
params.enable_low_precision_io = Attr<bool>("enable_low_precision_io");
params.use_inspector = Attr<bool>("use_inspector");
params.engine_info_path = Attr<std::string>("engine_info_path");

if (HasAttr("context_memory_sharing")) {
params.context_memory_sharing = Attr<bool>("context_memory_sharing");
}
if (HasAttr("use_dla")) {
params.use_dla = Attr<bool>("use_dla");
}
if (HasAttr("dla_core")) {
params.dla_core = Attr<int>("dla_core");
}
if (HasAttr("disable_trt_plugin_fp16")) {
params.disable_trt_plugin_fp16 = Attr<bool>("disable_trt_plugin_fp16");
}
if (HasAttr("enable_low_precision_io")) {
params.enable_low_precision_io = Attr<bool>("enable_low_precision_io");
}
if (HasAttr("use_inspector")) {
params.use_inspector = Attr<bool>("use_inspector");
}
if (HasAttr("engine_info_path")) {
params.engine_info_path = Attr<std::string>("engine_info_path");
}
if (HasAttr("optimization_level")) {
params.optimization_level = Attr<int>("optimization_level");
}
if (!shape_range_info_path_.empty()) {
inference::DeserializeShapeRangeInfo(shape_range_info_path_,
&params.min_input_shape,
test/ir/inference/test_trt_inference_fp16_io.py (4 additions, 2 deletions)
@@ -106,8 +106,9 @@ def init_predictor(self, low_precision_io: bool):
use_static=False,
use_calib_mode=False,
)
config.enable_tensorrt_memory_optim(True, 1)
config.enable_tuned_tensorrt_dynamic_shape()
config.enable_memory_optim()
config.enable_new_executor()
config.enable_low_precision_io(low_precision_io)
config.disable_glog_info()
predictor = create_predictor(config)
@@ -131,8 +132,9 @@ def init_predictor(self, low_precision_io: bool):
use_static=False,
use_calib_mode=False,
)
config.enable_tensorrt_memory_optim(True, 1)
config.enable_tuned_tensorrt_dynamic_shape()
config.enable_memory_optim()
config.enable_new_executor()
config.enable_low_precision_io(low_precision_io)
config.exp_disable_tensorrt_ops(["flatten_contiguous_range"])
config.disable_glog_info()
