diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index c7f1df07b2..cba869ed86 100755 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -118,19 +118,26 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch"); - auto out_shape = in_shape; + auto out_shape = in_shape; std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size())); if (!align_corners) { - //auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1"); - //auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners); - auto creator = new plugins::InterpolatePluginCreator(); + //auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners); + std::raise(SIGINT); - auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners); + //auto creator_auto = getPluginRegistry()->getPluginCreator("interpolate", "1"); + //auto plugin_auto = creator_auto->createPlugin(util::node_info(n).c_str(), nullptr); - auto resize_layer = ctx->net->addPluginV2(reinterpret_cast(in), 1, *plugin); + //auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1"); + + auto creator = new plugins::InterpolatePluginCreator(); + auto plugin = creator->createPlugin("interpolate_plugin", in_shape, out_shape, out_size, std::string("linear"), align_corners); + + auto resize_layer = ctx->net->addPluginV2(reinterpret_cast(&in), 1, *plugin); + resize_layer->setName(util::node_info(n).c_str()); auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions()); } else { auto resize_layer = ctx->net->addResize(*in); diff --git a/core/conversion/converters/impl/plugins/interpolate_plugin.cpp b/core/conversion/converters/impl/plugins/interpolate_plugin.cpp index a38e9e8674..6db6bca7e6 100755 --- a/core/conversion/converters/impl/plugins/interpolate_plugin.cpp +++ b/core/conversion/converters/impl/plugins/interpolate_plugin.cpp @@ -21,7 +21,6 @@ namespace conversion { namespace converters { namespace impl { namespace plugins { -namespace { /* * InterpolatePlugin class implementations @@ -64,6 +63,18 @@ InterpolatePlugin::InterpolatePlugin(const char *data, size_t length) { } } +std::vector InterpolatePlugin::getInputShape() { + return in_shape; +} + +std::vector InterpolatePlugin::getOutputShape() { + return out_shape; +} + +std::vector InterpolatePlugin::getOutputSize() { + return size; +} + int InterpolatePlugin::getNbOutputs() const { return 1; } @@ -206,7 +217,7 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co return nullptr; } -nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners) { +nvinfer1::IPluginV2DynamicExt* InterpolatePluginCreator::createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners) { name = name; return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners); } @@ -222,7 +233,6 @@ const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames() REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator); -} // namespace } // namespace plugins } // namespace impl } // namespace converters diff --git a/core/conversion/converters/impl/plugins/interpolate_plugin.h b/core/conversion/converters/impl/plugins/interpolate_plugin.h index 6afe75c72a..67a278b49c 100755 --- a/core/conversion/converters/impl/plugins/interpolate_plugin.h +++ b/core/conversion/converters/impl/plugins/interpolate_plugin.h @@ -22,7 +22,6 @@ namespace conversion { namespace converters { namespace impl { namespace plugins { -namespace { class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { private: @@ -52,6 +51,12 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt { InterpolatePlugin() = delete; + std::vector getInputShape(); + + std::vector getOutputShape(); + + std::vector getOutputSize(); + int getNbOutputs() const override; const char* getPluginType() const override; @@ -110,14 +115,13 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator { nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override; - nvinfer1::IPluginV2* createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners); + nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, std::vector in_shape, std::vector out_shape, std::vector size, std::string mode, bool align_corners); nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override; const nvinfer1::PluginFieldCollection* getFieldNames() override; }; -} // namespace } // namespace plugins } // namespace impl } // namespace converters