From f0fefaa9bed19dd1f9d0930c38e13f991011d083 Mon Sep 17 00:00:00 2001 From: Abhiram Iyer Date: Tue, 16 Jun 2020 12:45:25 -0700 Subject: [PATCH] fix(): trying to resolve interpolate plugin problems Signed-off-by: Abhiram Iyer Signed-off-by: Abhiram Iyer --- core/conversion/converters/BUILD | 2 +- .../converters/impl/interpolate.cpp | 33 +- core/conversion/converters/impl/plugins/BUILD | 6 +- .../impl/plugins/interpolate_plugin.cpp | 356 ++++++++---------- .../impl/plugins/interpolate_plugin.h | 128 +++++++ 5 files changed, 319 insertions(+), 206 deletions(-) create mode 100755 core/conversion/converters/impl/plugins/interpolate_plugin.h diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index ae3a3a27e7..56e693489c 100755 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -10,7 +10,7 @@ config_setting( cc_library( name = "converters", hdrs = [ - "converters.h", + "converters.h" ], srcs = [ "NodeConverterRegistry.cpp", diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index 4aa710798d..c7f1df07b2 100755 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -1,6 +1,8 @@ #include "torch/torch.h" #include "core/util/prelude.h" #include "core/conversion/converters/converters.h" +#include "NvInfer.h" +#include "plugins/interpolate_plugin.h" #include @@ -108,7 +110,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() auto in = args[0].ITensor(); auto in_shape = util::toVec(in->getDimensions()); - bool align_corners = args[2].IValue()->to(); + bool align_corners = args[2].unwrapToBool(); // Case 1: user uses output size and not scales if (!args[1].IValue()->isNone() && args[3].IValue()->isNone()) { @@ -119,16 +121,29 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); - auto resize_layer = ctx->net->addResize(*in); - TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n); + if (!align_corners) { + //auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1"); + //auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners); + auto creator = new plugins::InterpolatePluginCreator(); - resize_layer->setOutputDimensions(util::toDims(out_shape)); - resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR); - resize_layer->setAlignCorners(align_corners); - resize_layer->setName(util::node_info(n).c_str()); + auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners); - auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0)); - LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); + auto resize_layer = ctx->net->addPluginV2(reinterpret_cast(in), 1, *plugin); + + auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); + } else { + auto resize_layer = ctx->net->addResize(*in); + TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n); + + resize_layer->setOutputDimensions(util::toDims(out_shape)); + resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR); + resize_layer->setAlignCorners(align_corners); + resize_layer->setName(util::node_info(n).c_str()); + + auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); + } } else { TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet."); } diff --git a/core/conversion/converters/impl/plugins/BUILD b/core/conversion/converters/impl/plugins/BUILD index 9e22caf425..a9a96c27e1 100755 --- a/core/conversion/converters/impl/plugins/BUILD +++ b/core/conversion/converters/impl/plugins/BUILD @@ -9,7 +9,9 @@ config_setting( cc_library( name = "plugins", - hdrs = [], + hdrs = [ + "interpolate_plugin.h" + ], srcs = [ "interpolate_plugin.cpp" ], @@ -29,5 +31,5 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar") pkg_tar( name = "include", package_dir = "core/conversion/converters/impl/plugins", - srcs = [], + srcs = ["interpolate_plugin.h"], ) diff --git a/core/conversion/converters/impl/plugins/interpolate_plugin.cpp b/core/conversion/converters/impl/plugins/interpolate_plugin.cpp index 5672f4f4f1..a38e9e8674 100755 --- a/core/conversion/converters/impl/plugins/interpolate_plugin.cpp +++ b/core/conversion/converters/impl/plugins/interpolate_plugin.cpp @@ -1,15 +1,17 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "core/util/prelude.h" -#include "torch/torch.h" -#include "NvInfer.h" +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "core/util/prelude.h" +// #include "torch/torch.h" +// #include "NvInfer.h" + +#include "interpolate_plugin.h" using namespace nvinfer1; @@ -21,236 +23,202 @@ namespace impl { namespace plugins { namespace { -class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { -private: - at::TensorOptions tensor_options; - DataType dtype; - - std::vector in_shape; - std::vector out_shape; - std::vector size; - std::string mode; - bool align_corners; - -public: - InterpolatePlugin(std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners) : - in_shape(in_shape), out_shape(out_shape), size(size), mode(mode), align_corners(align_corners) - {} - - InterpolatePlugin(const char *data, size_t length) { - std::istringstream data_stream(std::string(data, length)); - - torch::serialize::InputArchive input_archive; - input_archive.load_from(data_stream); - - { - torch::IValue value; - input_archive.read("in_shape", value); - in_shape = value.toIntVector(); - } - { - torch::IValue value; - input_archive.read("out_shape", value); - out_shape = value.toIntVector(); - } - { - torch::IValue value; - input_archive.read("size", value); - size = value.toIntVector(); - } - { - torch::IValue value; - input_archive.read("mode", value); - mode = value.toStringRef(); - } - { - torch::IValue value; - input_archive.read("align_corners", value); - align_corners = value.toBool(); - } - } +/* + * InterpolatePlugin class implementations + */ - int getNbOutputs() const override { - return 1; - } +InterpolatePlugin::InterpolatePlugin(std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners) : + in_shape(in_shape), out_shape(out_shape), size(size), mode(mode), align_corners(align_corners) +{} - const char* getPluginType() const override { - return "Interpolate_TRTorch"; +InterpolatePlugin::InterpolatePlugin(const char *data, size_t length) { + std::istringstream data_stream(std::string(data, length)); + + torch::serialize::InputArchive input_archive; + input_archive.load_from(data_stream); + + { + torch::IValue value; + input_archive.read("in_shape", value); + in_shape = value.toIntVector(); } - - const char* getPluginVersion() const override { - return "1"; + { + torch::IValue value; + input_archive.read("out_shape", value); + out_shape = value.toIntVector(); } - - const char* getPluginNamespace() const override { - return "trtorch"; + { + torch::IValue value; + input_archive.read("size", value); + size = value.toIntVector(); } - - void setPluginNamespace(const char* pluginNamespace) {} - - int getTensorRTVersion() const override { - return NV_TENSORRT_MAJOR; + { + torch::IValue value; + input_archive.read("mode", value); + mode = value.toStringRef(); } - - nvinfer1::IPluginV2DynamicExt* clone() const override { - return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners); + { + torch::IValue value; + input_archive.read("align_corners", value); + align_corners = value.toBool(); } +} - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) override { - //nvinfer1::DimsExprs output(inputs[0]); +int InterpolatePlugin::getNbOutputs() const { + return 1; +} - // output.nbDims = out_shape.size(); +const char* InterpolatePlugin::getPluginType() const { + return "Interpolate_TRTorch"; +} - // for (int i = 0; i < out_shape.size(); i++) { - // output.d[i] = exprBuilder.getConstantValue(out_shape[i]); - // } +const char* InterpolatePlugin::getPluginVersion() const{ + return "1"; +} - // return output; - nvinfer1::DimsExprs empty; - return empty; - } - - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override { - return DataType::kFLOAT; - } +const char* InterpolatePlugin::getPluginNamespace() const { + return "trtorch"; +} - int initialize() override { - tensor_options = tensor_options.device(c10::kCUDA); - tensor_options = tensor_options.dtype(c10::kFloat); +int InterpolatePlugin::getTensorRTVersion() const { + return NV_TENSORRT_MAJOR; +} - return 0; - } +nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const { + return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners); +} - void terminate() override {} +nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) { + return inputs[0]; +} - void serialize(void* buffer) const override { - std::string data = serializeToString(); - size_t size = getSerializationSize(); +nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { + return DataType::kFLOAT; +} - data.copy((char *) buffer, size); - } +int InterpolatePlugin::initialize() { + tensor_options = tensor_options.device(c10::kCUDA); + tensor_options = tensor_options.dtype(c10::kFloat); - std::string serializeToString() const { - torch::serialize::OutputArchive output_archive; + return 0; +} - output_archive.write("in_shape", torch::IValue(in_shape)); - output_archive.write("out_shape", torch::IValue(out_shape)); - output_archive.write("size", torch::IValue(size)); - output_archive.write("mode", torch::IValue(mode)); - output_archive.write("align_corners", torch::IValue(align_corners)); - std::ostringstream data_str; - output_archive.save_to(data_str); +void InterpolatePlugin::serialize(void* buffer) const { + std::string data = serializeToString(); + size_t size = getSerializationSize(); - return data_str.str(); - } + data.copy((char*) buffer, size); +} - size_t getSerializationSize() const override { - return serializeToString().size(); - } +std::string InterpolatePlugin::serializeToString() const { + torch::serialize::OutputArchive output_archive; - void destroy() override {} + output_archive.write("in_shape", torch::IValue(in_shape)); + output_archive.write("out_shape", torch::IValue(out_shape)); + output_archive.write("size", torch::IValue(size)); + output_archive.write("mode", torch::IValue(mode)); + output_archive.write("align_corners", torch::IValue(align_corners)); - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override { - if (inOut->format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } + std::ostringstream data_str; + output_archive.save_to(data_str); - if (inOut->type == DataType::kINT32 || inOut->type == DataType::kINT8) { - return false; - } + return data_str.str(); +} - return true; - } +size_t InterpolatePlugin::getSerializationSize() const { + return serializeToString().size(); +} - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override { - dtype = DataType::kFLOAT; - } +bool InterpolatePlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) { + if (inOut->format != nvinfer1::TensorFormat::kLINEAR) { + return false; + } - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override { - return 0; + if (inOut->type == DataType::kINT32 || inOut->type == DataType::kINT8) { + return false; } - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, - cudaStream_t stream) override { - at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options); - at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options); - - at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); - at::cuda::CUDAStreamGuard torch_guard(torch_stream); + return true; +} - cudaEvent_t event; - cudaEventCreate(&event); - cudaEventRecord(event, stream); +void InterpolatePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) { + dtype = DataType::kFLOAT; +} - cudaStreamWaitEvent(torch_stream.stream(), event, 0); +size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const { + return 0; +} - if (mode == "linear") { - at::upsample_linear1d_out(output, input, {size[0]}, align_corners); - } else if (mode == "bilinear") { - at::upsample_bilinear2d_out(output, input, {size[0], size[1]}, align_corners); - } else if (mode == "trilinear") { - at::upsample_trilinear3d_out(output, input, {size[0], size[1], size[2]}, align_corners); - } +int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, + cudaStream_t stream) { + at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options); + at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options); - cudaEvent_t torch_event; - cudaEventCreate(&torch_event); - cudaEventRecord(torch_event, torch_stream.stream()); + at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool(); + at::cuda::CUDAStreamGuard torch_guard(torch_stream); - cudaStreamWaitEvent(stream, torch_event, 0); + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, stream); - cudaEventDestroy(event); - cudaEventDestroy(torch_event); + cudaStreamWaitEvent(torch_stream.stream(), event, 0); - return 0; + if (mode == "linear") { + at::upsample_linear1d_out(output, input, {size[0]}, align_corners); + } else if (mode == "bilinear") { + at::upsample_bilinear2d_out(output, input, {size[0], size[1]}, align_corners); + } else if (mode == "trilinear") { + at::upsample_trilinear3d_out(output, input, {size[0], size[1], size[2]}, align_corners); } -}; + cudaEvent_t torch_event; + cudaEventCreate(&torch_event); + cudaEventRecord(torch_event, torch_stream.stream()); -class InterpolatePluginCreator : public nvinfer1::IPluginCreator { -private: - std::string name; + cudaStreamWaitEvent(stream, torch_event, 0); -public: - InterpolatePluginCreator() {} + cudaEventDestroy(event); + cudaEventDestroy(torch_event); - int getTensorRTVersion() const override { - return NV_TENSORRT_MAJOR; - } + return 0; +} - const char *getPluginNamespace() const override { - return "trtorch"; - } +/* + * InterpolatePluginCreator class implementations + */ +const char* InterpolatePluginCreator::getPluginNamespace() const { + return "trtorch"; +} - void setPluginNamespace(const char* libNamespace) override {} - - const char *getPluginName() const override { - return "interpolate"; - } +void InterpolatePluginCreator::setPluginNamespace(const char* libNamespace) {} - const char *getPluginVersion() const override { - return "1"; - } +const char* InterpolatePluginCreator::getPluginName() const { + return "interpolate"; +} - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override { - return nullptr; - } +const char* InterpolatePluginCreator::getPluginVersion() const { + return "1"; +} - nvinfer1::IPluginV2* createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners) { - name = name; - return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners); - } +nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) { + return nullptr; +} - nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override { - name = name; - return new InterpolatePlugin((const char*) serialData, serialLength); - } +nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners) { + name = name; + return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners); +} - const nvinfer1::PluginFieldCollection* getFieldNames() override { - return nullptr; - } -}; +nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(const char* name, const void *serialData, size_t serialLength) { + name = name; + return new InterpolatePlugin((const char*) serialData, serialLength); +} + +const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames() { + return nullptr; +} REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator); diff --git a/core/conversion/converters/impl/plugins/interpolate_plugin.h b/core/conversion/converters/impl/plugins/interpolate_plugin.h new file mode 100755 index 0000000000..6afe75c72a --- /dev/null +++ b/core/conversion/converters/impl/plugins/interpolate_plugin.h @@ -0,0 +1,128 @@ +#ifndef TRTORCH_INTERPOLATE_PLUGIN_H +#define TRTORCH_INTERPOLATE_PLUGIN_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/util/prelude.h" +#include "torch/torch.h" +#include "NvInfer.h" + +using namespace nvinfer1; + +namespace trtorch { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace plugins { +namespace { + +class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { +private: + at::TensorOptions tensor_options; + DataType dtype; + + std::vector in_shape; + std::vector out_shape; + std::vector size; + std::string mode; + bool align_corners; + +protected: + // To prevent compiler warnings + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; + using nvinfer1::IPluginV2DynamicExt::configurePlugin; + using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; + using nvinfer1::IPluginV2DynamicExt::enqueue; + +public: + InterpolatePlugin(std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners); + + InterpolatePlugin(const char *data, size_t length); + + InterpolatePlugin() = delete; + + int getNbOutputs() const override; + + const char* getPluginType() const override; + + const char* getPluginVersion() const override; + + const char* getPluginNamespace() const override; + + void setPluginNamespace(const char* pluginNamespace) {} + + int getTensorRTVersion() const override; + + nvinfer1::IPluginV2DynamicExt* clone() const override; + + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) override; + + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override; + + int initialize() override; + + void terminate() override {} + + void serialize(void* buffer) const; + + std::string serializeToString() const; + + size_t getSerializationSize() const override; + + void destroy() override {} + + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, + cudaStream_t stream) override; +}; + +class InterpolatePluginCreator : public nvinfer1::IPluginCreator { +private: + std::string name; + +public: + InterpolatePluginCreator() = default; + + const char* getPluginNamespace() const override; + + void setPluginNamespace(const char* libNamespace) override; + + const char* getPluginName() const override; + + const char* getPluginVersion() const override; + + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override; + + nvinfer1::IPluginV2* createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners); + + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override; + + const nvinfer1::PluginFieldCollection* getFieldNames() override; +}; + +} // namespace +} // namespace plugins +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace trtorch + +#endif // TRTORCH_INTERPOLATE_PLUGIN_H \ No newline at end of file