Skip to content

Commit

Permalink
fix(plugin): trying to fix bug in plugin
Browse files Browse the repository at this point in the history
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com>

Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
  • Loading branch information
abhi-iyer committed Jun 17, 2020
1 parent f0fefaa commit cafcced
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
19 changes: 13 additions & 6 deletions core/conversion/converters/impl/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,26 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()

TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch");

auto out_shape = in_shape;
auto out_shape = in_shape;
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

if (!align_corners) {
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
//auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
auto creator = new plugins::InterpolatePluginCreator();
//auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
std::raise(SIGINT);

auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
//auto creator_auto = getPluginRegistry()->getPluginCreator("interpolate", "1");
//auto plugin_auto = creator_auto->createPlugin(util::node_info(n).c_str(), nullptr);

auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(in), 1, *plugin);
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");

auto creator = new plugins::InterpolatePluginCreator();
auto plugin = creator->createPlugin("interpolate_plugin", in_shape, out_shape, out_size, std::string("linear"), align_corners);

auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
resize_layer->setName(util::node_info(n).c_str());

auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
} else {
auto resize_layer = ctx->net->addResize(*in);
Expand Down
16 changes: 13 additions & 3 deletions core/conversion/converters/impl/plugins/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ namespace conversion {
namespace converters {
namespace impl {
namespace plugins {
namespace {

/*
* InterpolatePlugin class implementations
Expand Down Expand Up @@ -64,6 +63,18 @@ InterpolatePlugin::InterpolatePlugin(const char *data, size_t length) {
}
}

std::vector<int64_t> InterpolatePlugin::getInputShape() {
return in_shape;
}

std::vector<int64_t> InterpolatePlugin::getOutputShape() {
return out_shape;
}

std::vector<int64_t> InterpolatePlugin::getOutputSize() {
return size;
}

int InterpolatePlugin::getNbOutputs() const {
return 1;
}
Expand Down Expand Up @@ -206,7 +217,7 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
return nullptr;
}

nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
nvinfer1::IPluginV2DynamicExt* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
name = name;
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
}
Expand All @@ -222,7 +233,6 @@ const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames()

REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);

} // namespace
} // namespace plugins
} // namespace impl
} // namespace converters
Expand Down
10 changes: 7 additions & 3 deletions core/conversion/converters/impl/plugins/interpolate_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ namespace conversion {
namespace converters {
namespace impl {
namespace plugins {
namespace {

class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
private:
Expand Down Expand Up @@ -52,6 +51,12 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {

InterpolatePlugin() = delete;

std::vector<int64_t> getInputShape();

std::vector<int64_t> getOutputShape();

std::vector<int64_t> getOutputSize();

int getNbOutputs() const override;

const char* getPluginType() const override;
Expand Down Expand Up @@ -110,14 +115,13 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {

nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override;

nvinfer1::IPluginV2* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);
nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);

nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override;

const nvinfer1::PluginFieldCollection* getFieldNames() override;
};

} // namespace
} // namespace plugins
} // namespace impl
} // namespace converters
Expand Down

0 comments on commit cafcced

Please sign in to comment.