-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tensorrt convert init #10144
tensorrt convert init #10144
Changes from all commits
42febfa
48473dd
d599de5
c4e3010
326221a
6f6f330
9945265
beb1245
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# NOTE(review): the hunk header is "-1,4 +1,4", so the first four rows below look like the removed lines (tests guarded by WITH_TESTING) and the last four like their replacements - confirm against the real file. | ||
if(WITH_TESTING) | ||
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) | ||
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) | ||
endif() | ||
# The same two test targets, declared without the WITH_TESTING guard. | ||
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) | ||
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) | ||
# ENGINE_FILE is the path of engine.cc in this directory; the test_tensorrt_activation_op target lists ${ENGINE_FILE} among its SRCS. | ||
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc) | ||
# Pull in the converter sources and tests from the convert/ subdirectory. | ||
add_subdirectory(convert) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Test target for the op-converter registry; it is built together with the mul and conv2d converter sources. | ||
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES}) | ||
# Test target for the activation converter; ${ENGINE_FILE} adds the engine source to its SRCS, and activation_op is an extra dependency. | ||
nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc | ||
DEPS ${FLUID_CORE_MODULES} activation_op) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
class ReluOpConverter : public OpConverter { | ||
public: | ||
ReluOpConverter() {} | ||
void operator()(const framework::OpDesc& op) override { | ||
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " | ||
"type is Relu"; | ||
const nvinfer1::ITensor* input_tensor = | ||
engine_->GetITensor(op.Input("X")[0]); | ||
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( | ||
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), | ||
nvinfer1::ActivationType::kRELU); | ||
engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0)); | ||
} | ||
}; | ||
|
||
REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter); | ||
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
class Conv2dOpConverter : public OpConverter { | ||
public: | ||
Conv2dOpConverter() {} | ||
void operator()(const framework::OpDesc& op) override { | ||
LOG(INFO) | ||
<< "convert a fluid conv2d op to tensorrt conv layer without bias"; | ||
} | ||
}; | ||
|
||
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); | ||
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
class MulOpConverter : public OpConverter { | ||
public: | ||
MulOpConverter() {} | ||
void operator()(const framework::OpDesc& op) override { | ||
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; | ||
} | ||
}; | ||
|
||
REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter); | ||
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <unordered_map> | ||
#include "paddle/fluid/framework/block_desc.h" | ||
#include "paddle/fluid/framework/scope.h" | ||
#include "paddle/fluid/inference/tensorrt/engine.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
/* | ||
* Convert Op from Fluid to TensorRT Engine. | ||
*/ | ||
class OpConverter { | ||
public: | ||
OpConverter() {} | ||
virtual void operator()(const framework::OpDesc& op) {} | ||
|
||
void Execute(const framework::OpDesc& op, TensorRTEngine* engine) { | ||
std::string type = op.Type(); | ||
auto it = converters_.find(type); | ||
PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]", | ||
type); | ||
it->second->SetEngine(engine); | ||
(*it->second)(op); | ||
} | ||
|
||
static OpConverter& Global() { | ||
static auto* x = new OpConverter; | ||
return *x; | ||
} | ||
|
||
template <typename T> | ||
void Register(const std::string& key) { | ||
converters_[key] = new T; | ||
} | ||
|
||
// convert fluid op to tensorrt layer | ||
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) { | ||
OpConverter::Global().Execute(op, engine); | ||
} | ||
|
||
// convert fluid block to tensorrt network | ||
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) { | ||
for (auto op : block.AllOps()) { | ||
OpConverter::Global().Execute(*op, engine); | ||
} | ||
} | ||
|
||
void SetEngine(TensorRTEngine* engine) { engine_ = engine; } | ||
|
||
virtual ~OpConverter() {} | ||
|
||
// TensorRT engine | ||
TensorRTEngine* engine_{nullptr}; | ||
|
||
private: | ||
// registered op converter map, whose key is the fluid op type, and value is | ||
// the pointer position of corresponding OpConverter class. | ||
std::unordered_map<std::string, OpConverter*> converters_; | ||
// fluid inference scope | ||
framework::Scope* scope_{nullptr}; | ||
}; | ||
|
||
/*
 * Registers `Converter` as the converter for the fluid op type `op_type`.
 *
 * The macro expands to a helper struct named trt_<op_type>_converter and one
 * object of that struct named trt_<op_type>_converter__. The struct's
 * constructor calls OpConverter::Global().Register<Converter>() with the op
 * type spelled as a string, so registration happens when that object is
 * constructed.
 *
 * The expansion already ends with a semicolon.
 */
#define REGISTER_TRT_OP_CONVERTER(op_type, Converter)      \
  struct trt_##op_type##_converter {                       \
    trt_##op_type##_converter() {                          \
      OpConverter::Global().Register<Converter>(#op_type); \
    }                                                      \
  };                                                       \
  trt_##op_type##_converter trt_##op_type##_converter__;
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <gtest/gtest.h> | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" | ||
#include "paddle/fluid/platform/device_context.h" | ||
#include "paddle/fluid/platform/place.h" | ||
|
||
USE_OP(relu); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Move this to the bottom of the file so that readers needn't care about it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we move it to the bottom of the file, it seems to have no effect and the unit test fails. |
||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
void compare(float input, float expect) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Compare There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
framework::Scope scope; | ||
platform::CUDAPlace place; | ||
platform::CUDADeviceContext ctx(place); | ||
|
||
// init fluid op and variable | ||
auto x_var = scope.Var("X"); | ||
auto x_tensor = x_var->GetMutable<framework::LoDTensor>(); | ||
x_tensor->Resize({1, 1}); | ||
std::vector<float> init; | ||
init.push_back(input); | ||
framework::TensorFromVector(init, ctx, x_tensor); | ||
|
||
auto out_var = scope.Var("Out"); | ||
auto out_tensor = out_var->GetMutable<framework::LoDTensor>(); | ||
out_tensor->Resize({1, 1}); | ||
out_tensor->mutable_data<float>(place); | ||
|
||
framework::OpDesc op_desc; | ||
op_desc.SetType("relu"); | ||
op_desc.SetInput("X", {"X"}); | ||
op_desc.SetOutput("Out", {"Out"}); | ||
|
||
auto relu_op = framework::OpRegistry::CreateOp(op_desc); | ||
|
||
// run fluid op | ||
relu_op->Run(scope, place); | ||
std::vector<float> out1; | ||
framework::TensorToVector(*out_tensor, ctx, &out1); | ||
|
||
// init tensorrt op | ||
cudaStream_t stream; | ||
ASSERT_EQ(0, cudaStreamCreate(&stream)); | ||
TensorRTEngine* engine = new TensorRTEngine(1, 1 << 10, &stream); | ||
engine->InitNetwork(); | ||
engine->DeclareInput("X", nvinfer1::DataType::kFLOAT, | ||
nvinfer1::DimsCHW{1, 1, 1}); | ||
|
||
OpConverter op_converter; | ||
op_converter.ConvertOp(op_desc, engine); | ||
|
||
engine->DeclareOutput("Out"); | ||
engine->FreezeNetwork(); | ||
engine->SetInputFromCPU("X", &input, 1 * sizeof(float)); | ||
|
||
// run tensorrt op | ||
engine->Execute(1); | ||
|
||
float out2; | ||
engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float)); | ||
|
||
ASSERT_EQ(out1[0], out2); | ||
ASSERT_EQ(out1[0], expect); | ||
|
||
delete engine; | ||
cudaStreamDestroy(stream); | ||
} | ||
|
||
// Checks the relu converter on one positive and one negative input.
TEST(OpConverter, ConvertRelu) {
  // Each row is {input, expected relu output}.
  const float kCases[][2] = {
      {1.0f, 1.0f},   // relu(1) = 1
      {-5.0f, 0.0f},  // relu(-5) = 0
  };
  for (const auto& io : kCases) {
    compare(io[0], io[1]);
  }
}
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <gtest/gtest.h> | ||
#include "paddle/fluid/framework/program_desc.h" | ||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
// Builds a block holding a "mul" op followed by a "conv2d" op and passes it
// through ConvertBlock with no engine attached. The test contains no
// assertions; it only exercises the conversion call.
TEST(BlockConverter, ConvertBlock) {
  framework::ProgramDesc program;
  auto* main_block = program.MutableBlock(0);

  // Ops are appended in this order.
  const char* const kOpTypes[] = {"mul", "conv2d"};
  for (const char* op_type : kOpTypes) {
    main_block->AppendOp()->SetType(op_type);
  }

  OpConverter block_converter;
  block_converter.ConvertBlock(*main_block, nullptr /*TensorRTEngine*/);
}
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以参考我pr里的写法,把单例和接口拆开。 之前王叔说过这个三种设计模式混在一起。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be refactored later, once utils/singleton.h from the convert_io PR is merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done