diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index 42e4972362..cb4b998cd4 100755
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -7,6 +7,25 @@ config_setting(
     }
 )
 
+cc_library(
+    name = "weights",
+    hdrs = [
+        "Weights.h"
+    ],
+    srcs = [
+        "Weights.cpp"
+    ],
+    deps = [
+        "@tensorrt//:nvinfer",
+        "//core/util:prelude",
+        "//core/conversion/conversionctx"
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    alwayslink = True,
+)
+
 cc_library(
     name = "converters",
     hdrs = [
@@ -14,7 +33,6 @@ cc_library(
     ],
     srcs = [
         "NodeConverterRegistry.cpp",
-        "Weights.cpp",
         "impl/activation.cpp",
         "impl/batch_norm.cpp",
         "impl/concat.cpp",
@@ -51,5 +69,8 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
 pkg_tar(
     name = "include",
     package_dir = "core/conversion/converters/",
-    srcs = ["converters.h"],
+    srcs = [
+        "converters.h",
+        "Weights.h"
+    ],
 )
diff --git a/core/conversion/converters/Weights.cpp b/core/conversion/converters/Weights.cpp
index 06d5a4d050..7e0519cc50 100644
--- a/core/conversion/converters/Weights.cpp
+++ b/core/conversion/converters/Weights.cpp
@@ -1,12 +1,11 @@
 #include "core/util/prelude.h"
-#include "core/conversion/converters/converters.h"
+#include "core/conversion/converters/Weights.h"
 
 namespace trtorch {
 namespace core {
 namespace conversion {
 namespace converters {
 
-
 Weights::Weights() {
     this->num_input_maps = 0;
     this->num_output_maps = 0;
@@ -18,20 +17,36 @@ Weights::Weights() {
 
 Weights::Weights(ConversionCtx* ctx, float val) {
     this->num_input_maps = 1;
     this->num_output_maps = 1;
+    this->data.type = nvinfer1::DataType::kFLOAT;
     float* buf = reinterpret_cast<float*>(malloc(1 * sizeof(float)));
     buf[0] = val;
     this->data.values = buf;
     this->data.count = 1;
     ctx->builder_resources.push_back(buf);
-    this->kernel_shape.nbDims = 1;
-    this->kernel_shape.d[0] = 1;
+
+    this->shape.nbDims = 0;
+    this->kernel_shape.nbDims = 0;
+}
+
+Weights::Weights(ConversionCtx* ctx, int32_t val) {
+    this->num_input_maps = 1;
+    this->num_output_maps = 1;
+
+    this->data.type = nvinfer1::DataType::kINT32;
+    int32_t* buf = reinterpret_cast<int32_t*>(malloc(1 * sizeof(int32_t)));
+    buf[0] = val;
+    this->data.values = buf;
+    this->data.count = 1;
+    ctx->builder_resources.push_back(buf);
+
+    this->shape.nbDims = 0;
+    this->kernel_shape.nbDims = 0;
 }
 
 Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
     if (t.sizes().size() > nvinfer1::Dims::MAX_DIMS) {
-        //TODO: Handle this with exceptions or whatever
-        LOG_INTERNAL_ERROR("The tensor requested to be converted to nvinfer1::Weights exceeds the max number of dimensions for TensorRT");
+        TRTORCH_THROW_ERROR("The tensor requested to be converted to nvinfer1::Weights exceeds the max number of dimensions for TensorRT");
     }
     this->shape = util::toDims(t.sizes());
     if (t.sizes().size() >= 2) {
@@ -59,9 +74,7 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
     t_cpu = t_cpu.contiguous();
     auto dtype_optional = util::toTRTDataType(t_cpu.dtype());
     if (!dtype_optional) {
-        //TODO: Handle this with exceptions or whatever
-        //TODO: Implement handling for the Torch Types
-        LOG_INTERNAL_ERROR("The tensor requested to be converted to nvinfer1::Weights is of an unsupported type");
+        TRTORCH_THROW_ERROR("The tensor requested to be converted to nvinfer1::Weights is of an unsupported type");
     }
 
     // Store the data in the conversion context so it remains until building is complete
diff --git a/core/conversion/converters/Weights.h b/core/conversion/converters/Weights.h
new file mode 100755
index 0000000000..78f1db3842
--- /dev/null
+++ b/core/conversion/converters/Weights.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include "core/util/prelude.h"
+#include "core/conversion/conversionctx/ConversionCtx.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+
+struct Weights {
+    nvinfer1::Weights data;
+    nvinfer1::Dims kernel_shape;
+    nvinfer1::Dims shape;
+    int64_t num_input_maps;
+    int64_t num_output_maps;
+
+    Weights();
+    Weights(ConversionCtx* ctx, at::Tensor t);
+    Weights(ConversionCtx* ctx, float val);
+    Weights(ConversionCtx* ctx, int32_t val);
+    friend std::ostream& operator<<(std::ostream& os, const Weights& w);
+};
+
+inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
+    auto t_weights = Weights(ctx, t);
+    auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
+    TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
+
+    auto out = const_layer->getOutput(0);
+
+    std::ostringstream tensor_id;
+    tensor_id << reinterpret_cast<int*>(out);
+
+    LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
+    const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
+
+    return out;
+}
+
+
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
\ No newline at end of file
diff --git a/core/conversion/converters/converters.h b/core/conversion/converters/converters.h
index 18b1fc376d..42cb96d010 100644
--- a/core/conversion/converters/converters.h
+++ b/core/conversion/converters/converters.h
@@ -9,6 +9,7 @@
 #include "core/util/prelude.h"
 #include "core/conversion/var/Var.h"
 #include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/conversion/converters/Weights.h"
 
 namespace trtorch {
 namespace core {
@@ -39,28 +40,6 @@ class RegisterNodeConversionPatterns {
 bool node_is_convertable(const torch::jit::Node* n);
 OpConverter get_node_converter_for(const torch::jit::FunctionSchema* signature);
 
-struct Weights {
-    //TODO: Rebuild this in a way that makes sense for more than just conv2/3D and linear
-    nvinfer1::Weights data;
-    nvinfer1::Dims kernel_shape;
-    nvinfer1::Dims shape;
-    int64_t num_input_maps;
-    int64_t num_output_maps;
-
-    Weights();
-    Weights(ConversionCtx* ctx, at::Tensor t);
-    Weights(ConversionCtx* ctx, float val);
-    friend std::ostream& operator<<(std::ostream& os, const Weights& w);
-};
-
-inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
-    auto t_weights = Weights(ctx, t);
-    auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
-    TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
-    const_layer->setName("[Freeze Tensor]");
-    return const_layer->getOutput(0);
-}
-
 } // namespace converters
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/converters/impl/activation.cpp b/core/conversion/converters/impl/activation.cpp
index 31ac992414..fe15335821 100644
--- a/core/conversion/converters/impl/activation.cpp
+++ b/core/conversion/converters/impl/activation.cpp
@@ -10,7 +10,7 @@ namespace {
 
 #define convert(act, trt_type)                                                 \
     bool act(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {     \
-        auto in = args[0].ITensor();                                           \
+        auto in = args[0].ITensorOrFreeze(ctx);                                \
                                                                                \
         auto new_layer =                                                       \
             ctx->net->addActivation(*in, nvinfer1::ActivationType::trt_type); \
@@ -46,7 +46,7 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     .pattern({
         "aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto min = args[1].unwrapToDouble();
             auto max = args[2].unwrapToDouble();
 
@@ -66,7 +66,7 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
         //TODO: Remove after functionalization
         "aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto min = args[1].unwrapToDouble();
             auto max = args[2].unwrapToDouble();
 
diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp
index a7b6045737..3bac690d18 100644
--- a/core/conversion/converters/impl/batch_norm.cpp
+++ b/core/conversion/converters/impl/batch_norm.cpp
@@ -15,7 +15,7 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
                         Tensor? mean, Tensor? var,
                         bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto input = args[0].ITensor();
+            auto input = args[0].ITensor(); // assumes non-static input Tensor
             auto orig_shape = input->getDimensions();
             auto shape = util::toVec(orig_shape);
             auto options = torch::TensorOptions().dtype(torch::kFloat32);
diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp
index 8a76557381..2cb89d5463 100644
--- a/core/conversion/converters/impl/conv_deconv.cpp
+++ b/core/conversion/converters/impl/conv_deconv.cpp
@@ -17,7 +17,7 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensor(); // assumes non-static input Tensor
 
             auto w = Weights(ctx, args[1].unwrapToTensor());
             auto stride = util::toDims(args[3].unwrapToIntList());
diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index 4cb2e03a19..6badba97ef 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -26,7 +26,6 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
         self = self_shuffle->getOutput(0);
     }
 
-
     nvinfer1::ILayer* ele;
     if (scalar != 1) {
         LOG_WARNING("Please verify scalar handling in add converter, channel axis set to 3 but scaling is uniform");
@@ -73,8 +72,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self + alpha * other
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto scalar = args[2].unwrapToScalar().to<float>();
             auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);
 
@@ -90,8 +89,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self + alpha * other
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto scalar = args[2].unwrapToScalar().to<float>();
             auto add = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, self, other, util::node_info(n), scalar);
 
@@ -107,8 +106,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self - alpha * other
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto scalar = args[2].unwrapToScalar().to<float>();
             auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, other, util::node_info(n), scalar);
 
@@ -124,8 +123,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self / other
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
 
             TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -140,8 +139,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // TODO: Remove with functionalization
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
 
             TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
@@ -156,8 +155,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // Should implement self * other
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
 
             TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
@@ -172,8 +171,8 @@ auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
         "aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // TODO: Remove with functionalization
-            auto self = args[0].ITensor();
-            auto other = args[1].ITensor();
+            auto self = args[0].ITensorOrFreeze(ctx);
+            auto other = args[1].ITensorOrFreeze(ctx);
             auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
 
             TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
diff --git a/core/conversion/converters/impl/linear.cpp b/core/conversion/converters/impl/linear.cpp
index e22664afe0..3fade058ac 100644
--- a/core/conversion/converters/impl/linear.cpp
+++ b/core/conversion/converters/impl/linear.cpp
@@ -14,7 +14,7 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
             // PyTorch follows in: Nx*xIN, W: OUTxIN, B: OUT, out: Nx*xOUT
             // TensorRT inserts a flatten in when following conv
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto shape = util::toVec(in->getDimensions());
 
             LOG_DEBUG("Input tensor shape: " << in->getDimensions());
diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp
index cbebdc13b2..07e32cc1f0 100644
--- a/core/conversion/converters/impl/matrix_multiply.cpp
+++ b/core/conversion/converters/impl/matrix_multiply.cpp
@@ -12,30 +12,10 @@ auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     .pattern({
         "aten::matmul(Tensor self, Tensor other) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            nvinfer1::ITensor* self;
-            if (args[0].isIValue()) {
-                auto t = args[0].unwrapToTensor();
-                auto t_weights = Weights(ctx, t);
-                auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
-                TRTORCH_CHECK(const_layer, "Unable to freeze tensor self for node: " << *n);
-                const_layer->setName((util::node_info(n) + " [Freeze Tensor(self)]").c_str());
-                self = const_layer->getOutput(0);
-            } else {
-                self = args[0].ITensor();
-            }
+            auto self = args[0].ITensorOrFreeze(ctx);
             LOG_DEBUG("self tensor shape: " << self->getDimensions());
 
-            nvinfer1::ITensor* other;
-            if (args[1].isIValue()) {
-                auto t = args[1].unwrapToTensor();
-                auto t_weights = Weights(ctx, t);
-                auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
-                TRTORCH_CHECK(const_layer, "Unable to freeze tensor other for node: " << *n);
-                const_layer->setName((util::node_info(n) + " [Freeze Tensor(other)]").c_str());
-                other = const_layer->getOutput(0);
-            } else {
-                other = args[1].ITensor();
-            }
+            auto other = args[1].ITensorOrFreeze(ctx);
             LOG_DEBUG("other tensor shape: " << other->getDimensions());
 
             auto mm_layer = ctx->net->addMatrixMultiply(*self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp
index c83c90fef0..7c42927824 100644
--- a/core/conversion/converters/impl/pooling.cpp
+++ b/core/conversion/converters/impl/pooling.cpp
@@ -11,7 +11,7 @@ namespace impl {
 namespace {
 
 bool MaxPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-    auto in = args[0].ITensor();
+    auto in = args[0].ITensorOrFreeze(ctx);
     auto shape = util::toVec(in->getDimensions());
 
     // Max Pool needs at least 4D input
@@ -65,7 +65,7 @@ bool MaxPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
 }
 
 bool AvgPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-    auto in = args[0].ITensor();
+    auto in = args[0].ITensorOrFreeze(ctx);
     auto shape = util::toVec(in->getDimensions());
 
     // Avg Pool needs at least 4D input
@@ -122,7 +122,7 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     .pattern({
         "aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto shape = util::toVec(in->getDimensions());
 
             // Max Pool needs at least 4D input
@@ -182,7 +182,7 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto shape = util::toVec(in->getDimensions());
 
             // Avg Pool needs at least 4D input
@@ -262,7 +262,7 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto in_shape = util::toVec(in->getDimensions());
 
             if (in_shape.size() < 4) {
diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 16e0d9dd83..3d8c23d1b9 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -15,7 +15,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     .pattern({
         "aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto in_dims = util::toVec(in_tensor->getDimensions());
             LOG_WARNING("Mean Converter disregards dtype");
 
@@ -34,7 +34,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto dims = args[1].unwrapToIntList();
             LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info
 
@@ -61,7 +61,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto in_dims = util::toVec(in_tensor->getDimensions());
             LOG_WARNING("Sum Converter disregards dtype");
 
@@ -80,7 +80,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto dims = args[1].unwrapToIntList();
             LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info
 
@@ -107,7 +107,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto in_dims = util::toVec(in_tensor->getDimensions());
             LOG_WARNING("Prod Converter disregards dtype");
 
@@ -126,7 +126,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto dim = args[1].unwrapToInt();
             LOG_DEBUG("Dim to reduce:" << dim); // Some abuse of toDim but just for debug info
 
@@ -150,7 +150,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::max(Tensor self) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto in_dims = util::toVec(in_tensor->getDimensions());
 
             uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
@@ -168,7 +168,7 @@ auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     }).pattern({
         "aten::min(Tensor self) -> Tensor",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in_tensor = args[0].ITensor();
+            auto in_tensor = args[0].ITensorOrFreeze(ctx);
             auto in_dims = util::toVec(in_tensor->getDimensions());
 
             uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index aa06b9f5e9..6457c42883 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -13,7 +13,7 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
     .pattern({
         "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto start_dim = args[1].unwrapToInt();
             auto end_dim = args[2].unwrapToInt();
             auto in_shape = util::toVec(in->getDimensions());
@@ -36,7 +36,7 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
     }).pattern({
         "aten::reshape(Tensor self, int[] shape) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto in_shape = util::toVec(in->getDimensions());
             std::vector<int64_t> new_shape;
             if (ctx->input_is_dynamic) {
@@ -58,7 +58,7 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
     }).pattern({
         "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto in_shape = util::toVec(in->getDimensions());
 
             auto shuffle = ctx->net->addShuffle(*in);
@@ -74,7 +74,7 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
     }).pattern({
         "aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto in_shape = util::toVec(in->getDimensions());
             auto new_order = args[1].unwrapToIntList().vec();
 
diff --git a/core/conversion/converters/impl/softmax.cpp b/core/conversion/converters/impl/softmax.cpp
index 6a81b974a2..6e5213774b 100644
--- a/core/conversion/converters/impl/softmax.cpp
+++ b/core/conversion/converters/impl/softmax.cpp
@@ -11,7 +11,7 @@ static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
     .pattern({
         "aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-            auto in = args[0].ITensor();
+            auto in = args[0].ITensorOrFreeze(ctx);
             auto shape = util::toVec(in->getDimensions());
 
             // SoftMax needs at least 4D input
diff --git a/core/conversion/var/BUILD b/core/conversion/var/BUILD
index 247f939e48..fb4ee14175 100644
--- a/core/conversion/var/BUILD
+++ b/core/conversion/var/BUILD
@@ -19,6 +19,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/conversion/converters:weights"
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],
diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp
index 7417d74db7..6aa2a200b0 100644
--- a/core/conversion/var/Var.cpp
+++ b/core/conversion/var/Var.cpp
@@ -1,3 +1,5 @@
+#include <sstream>
+
 #include "core/util/prelude.h"
 #include "core/conversion/var/Var.h"
 
@@ -85,6 +87,36 @@ std::string Var::type_name() const {
     }
 }
 
+nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
+    if (isIValue()) {
+        LOG_DEBUG(ctx->logger, "Found IValue containing object of type " << *(ptr_.ivalue->type()));
+    }
+    TRTORCH_CHECK(isITensor() || (isIValue() && ptr_.ivalue->isTensor()), "Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
+
+    nvinfer1::ITensor* out;
+
+    if (isIValue()) {
+        auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
+
+        auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
+        TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
+
+        out = const_layer->getOutput(0);
+
+        std::ostringstream tensor_id;
+        tensor_id << reinterpret_cast<int*>(out);
+
+        LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
+        const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
+    } else {
+        out = ptr_.tensor;
+    }
+
+    LOG_DEBUG("Frozen tensor shape: " << out->getDimensions());
+
+    return out;
+}
+
 const torch::jit::IValue* Var::IValue() const {
     TRTORCH_CHECK(isIValue(), "Requested IValue from Var, however Var type is " << type_name());
     if (type_ == Type::kIValue) {
diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h
index 526db246b8..2b57c35241 100644
--- a/core/conversion/var/Var.h
+++ b/core/conversion/var/Var.h
@@ -4,7 +4,8 @@
 #include
 
 #include "torch/csrc/jit/ir/ir.h"
-
+#include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/conversion/converters/Weights.h"
 #include "core/util/prelude.h"
 namespace trtorch {
 namespace core {
@@ -48,6 +49,7 @@ class Var : torch::CustomClassHolder {
     c10::List<bool> unwrapToBoolList();
     c10::List<at::Tensor> unwrapToTensorList(c10::List<at::Tensor> default_val);
     c10::List<at::Tensor> unwrapToTensorList();
+    nvinfer1::ITensor* ITensorOrFreeze(ConversionCtx* ctx);
 
     template <typename T>
     T unwrapTo(T default_val);