Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix explicit quantization conv2d error #58015

Merged
merged 2 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 35 additions & 43 deletions paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,45 +43,16 @@ void ConvertConv2d(TensorRTEngine* engine,
framework::OpDesc op_desc(op, nullptr);

auto* X = engine->GetITensor(op_desc.Input("Input").front());
bool enable_int8 = op_desc.HasAttr("enable_int8");

if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
float in_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine->SetTensorDynamicRange(X, in_scale);
#endif
}

std::string filter_var_name = op_desc.Input("Filter").front();
auto* Y_v = scope.FindVar(filter_var_name);
phi::DenseTensor* Y_t = nullptr;
nvinfer1::ITensor* filter = nullptr;
int n_output;
int n_input;
int filter_h;
int filter_w;
std::string filter_var_name = op_desc.Input("Filter").front();
TensorRTEngine::Weight weight;
if (engine->use_explicit_quantization()) {
auto* filter = engine->GetITensor(filter_var_name);
PADDLE_ENFORCE_NOT_NULL(
filter,
platform::errors::NotFound("Can not find %s ITensor in engine",
filter_var_name));
auto filter_dims = filter->getDimensions();
PADDLE_ENFORCE_EQ(
filter_dims.nbDims,
4UL,
platform::errors::InvalidArgument(
"The conv2d filter's dims size should be 4, but got %d",
filter_dims.nbDims));
n_output = filter_dims.d[0];
n_input = filter_dims.d[1];
filter_h = filter_dims.d[2];
filter_w = filter_dims.d[3];
} else {
auto* Y_v = scope.FindVar(filter_var_name);
PADDLE_ENFORCE_NOT_NULL(
Y_v,
platform::errors::NotFound("Can not find %s presistale var in scope.",
filter_var_name));
auto* Y_t = Y_v->GetMutable<phi::DenseTensor>();
if (Y_v) {
Y_t = Y_v->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
Y_t->dims().size(),
4UL,
Expand All @@ -92,7 +63,27 @@ void ConvertConv2d(TensorRTEngine* engine,
n_input = Y_t->dims()[1];
filter_h = Y_t->dims()[2];
filter_w = Y_t->dims()[3];
weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t);
} else {
filter = engine->GetITensor(op_desc.Input("Filter").front());
PADDLE_ENFORCE_EQ(
filter->getDimensions().nbDims,
4UL,
platform::errors::InvalidArgument(
"The conv2d filter's dims size should be 4, but got %d",
filter->getDimensions().nbDims));
n_output = filter->getDimensions().d[0];
n_input = filter->getDimensions().d[1];
filter_h = filter->getDimensions().d[2];
filter_w = filter->getDimensions().d[3];
}

bool enable_int8 = op_desc.HasAttr("enable_int8");

if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
float in_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine->SetTensorDynamicRange(X, in_scale);
#endif
}
const int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
const std::vector<int> dilations =
Expand Down Expand Up @@ -133,7 +124,10 @@ void ConvertConv2d(TensorRTEngine* engine,
nv_post_paddings.d[0] = paddings[1];
nv_post_paddings.d[1] = paddings[3];
}

TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, nullptr, 0);
if (Y_v) {
weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t);
}
TensorRTEngine::Weight bias;
bias.SetDataType(weight.get().type);
bias.SetCount(0);
Expand Down Expand Up @@ -167,7 +161,10 @@ void ConvertConv2d(TensorRTEngine* engine,
layer->setStrideNd(nv_strides);

layer->setPrePadding(nv_pre_paddings);
if (output_padding.size() > 0) {

if (!Y_v) layer->setInput(1, *filter);

if (!output_padding.empty()) {
nv_post_paddings.d[0] -= output_padding[0];
nv_post_paddings.d[1] -= output_padding[1];
}
Expand All @@ -186,11 +183,6 @@ void ConvertConv2d(TensorRTEngine* engine,
// set dilations
fset_dilation(layer, nv_dilations);

if (engine->use_explicit_quantization()) {
auto* filter_tensor = engine->GetITensor(op_desc.Input("Filter").front());
layer->setInput(1, *filter_tensor);
}

auto output_name = op_desc.Output("Output").front();
layer->setName((name + " (Output: " + output_name + ")").c_str());
layer->getOutput(0)->setName(output_name.c_str());
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,12 @@ struct SimpleOpTypeSetTeller : public Teller {
auto* block = desc.Block();
if (block) {
auto* filter_var_desc = block->FindVar(desc.Input("Filter")[0]);
if (!filter_var_desc->Persistable() && !use_explicit_quantization) {
if (!filter_var_desc->Persistable()) {
#if IS_TRT_VERSION_GE(8600)
#else
LOG(INFO)
<< "Trt below 8.6 not support conv2d's filter is a intermedoate "
"tensor in conv2d op, please upgarde your TenroRT.";
"tensor in conv2d op, please upgarde your TensorRT.";
return false;
#endif
}
Expand Down
10 changes: 6 additions & 4 deletions test/ir/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,12 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_trt_inference_predictor PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_inference_fp16_io PROPERTIES TIMEOUT 300)
set_tests_properties(test_trt_optimization_level PROPERTIES TIMEOUT 300)
set_tests_properties(test_trt_explicit_quantization_resnet PROPERTIES TIMEOUT
300)
set_tests_properties(test_trt_explicit_quantization_mobilenet
PROPERTIES TIMEOUT 300)
if(NOT WIN32)
set_tests_properties(test_trt_explicit_quantization_resnet
PROPERTIES TIMEOUT 300)
set_tests_properties(test_trt_explicit_quantization_mobilenet
PROPERTIES TIMEOUT 300)
endif()
if(WITH_MKLDNN)
set_tests_properties(test_save_optimized_model_pass PROPERTIES TIMEOUT 300)
endif()
Expand Down