refactor: Reorganize and consolidate tensor freezing
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Aug 5, 2021
1 parent 152b377 commit 0647d17
Showing 16 changed files with 114 additions and 71 deletions.
26 changes: 24 additions & 2 deletions core/conversion/converters/BUILD
@@ -26,11 +26,32 @@ cc_library(
alwayslink = True,
)


cc_library(
name = "converter_util",
srcs = [
"converter_util.cpp"
],
hdrs = [
"converter_util.h"
],
deps = [
":weights",
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/conversion/conversionctx",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
alwayslink = True,
)


cc_library(
name = "converters",
srcs = [
"NodeConverterRegistry.cpp",
"converter_util.cpp",
"impl/activation.cpp",
"impl/batch_norm.cpp",
"impl/concat.cpp",
@@ -59,7 +80,6 @@ cc_library(
"impl/unsqueeze.cpp",
],
hdrs = [
"converter_util.h",
"converters.h",
],
deps = [
@@ -69,6 +89,7 @@
"//core/conversion/tensorcontainer",
"//core/conversion/conversionctx",
"//core/plugins:trtorch_plugins",
":converter_util"
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
@@ -83,6 +104,7 @@ pkg_tar(
srcs = [
"Weights.h",
"converters.h",
"converter_util.h"
],
package_dir = "core/conversion/converters/",
)
4 changes: 2 additions & 2 deletions core/conversion/converters/Weights.cpp
@@ -73,9 +73,9 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
}
auto t_cpu = t.to(at::kCPU);
t_cpu = t_cpu.contiguous();
auto dtype_optional = util::toTRTDataType(t_cpu.dtype());
auto dtype_optional = util::optScalarTypeToTRTDataType(t_cpu.scalar_type());
if (!dtype_optional) {
TRTORCH_THROW_ERROR("The tensor requested to be converted to nvinfer1::Weights is of an unsupported type");
TRTORCH_THROW_ERROR("The tensor requested to be converted to nvinfer1::Weights is of an unsupported type: " << dtype_optional.value());
}

// Store the data in the conversion context so it remains until building is
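A note on the error path above: the throw fires on the branch where dtype_optional is empty, so the message streams the source scalar type rather than reading the optional itself; calling .value() on an empty optional raises its own exception before the intended diagnostic can surface. A minimal demonstration of the hazard, using std::optional where the codebase uses c10::optional (which behaves the same way for this purpose):

#include <iostream>
#include <optional>

int main() {
  std::optional<int> dtype_optional; // empty, as on the failed-lookup branch above
  try {
    std::cout << dtype_optional.value(); // throws std::bad_optional_access
  } catch (const std::bad_optional_access&) {
    std::cerr << "lookup failed; report the original dtype instead\n";
  }
}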
16 changes: 0 additions & 16 deletions core/conversion/converters/Weights.h
@@ -22,22 +22,6 @@ struct Weights {
friend std::ostream& operator<<(std::ostream& os, const Weights& w);
};

inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
auto t_weights = Weights(ctx, t);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor");

auto out = const_layer->getOutput(0);

std::ostringstream tensor_id;
tensor_id << reinterpret_cast<int*>(out);

LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());

return out;
}

} // namespace converters
} // namespace conversion
} // namespace core
46 changes: 45 additions & 1 deletion core/conversion/converters/converter_util.cpp
@@ -1,5 +1,4 @@
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

@@ -143,6 +142,51 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nv
}
}

nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
bool post_freeze_cast = false;
nvinfer1::DataType post_freeze_cast_type = nvinfer1::DataType::kFLOAT;
// Other "unsupported weights types" can be added to this check here
if (t.scalar_type() == at::kBool) {
post_freeze_cast = true;
auto type = util::ScalarTypeToTRTDataType(t.scalar_type());
post_freeze_cast_type = type;
LOG_DEBUG("To cast layer back to " << post_freeze_cast_type << " from int after freezing");
t = t.to(at::kFloat);
}

auto weights = Weights();
if ((t.scalar_type() == at::kLong || t.scalar_type() == at::kDouble) &&
!ctx->settings.truncate_long_and_double) {
TRTORCH_THROW_ERROR(
"Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
} else if (t.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, t.toType(at::kInt));
LOG_WARNING("Truncating weight (constant in the graph) from Int64 to Int32");
} else if (t.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, t.toType(at::kFloat));
LOG_WARNING("Truncating weight (constant in the graph) from Float64 to Float32");
} else {
weights = Weights(ctx, t);
}

auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor");

auto out = const_layer->getOutput(0);

std::ostringstream tensor_id;
tensor_id << reinterpret_cast<int*>(out);

LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());

if (post_freeze_cast) {
out = castITensor(ctx, out, post_freeze_cast_type);
}

return out;
}

} // namespace converters
} // namespace conversion
} // namespace core
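The dtype policy in the new tensor_to_const is easiest to read in isolation. Below is a standalone sketch of just the decision logic; the enum and the helper name freeze_storage_type are ours purely for illustration, while the real function operates on at::Tensor, reads the flag from ctx->settings, builds the IConstantLayer, and emits the warnings shown above:

#include <iostream>
#include <stdexcept>

enum class ScalarType { Float, Int, Long, Double, Bool };

// Type a weight tensor is actually stored as once frozen into the engine.
ScalarType freeze_storage_type(ScalarType t, bool truncate_long_and_double) {
  switch (t) {
    case ScalarType::Bool:
      return ScalarType::Float; // frozen as float, cast back to bool after freezing
    case ScalarType::Long:
      if (!truncate_long_and_double)
        throw std::runtime_error("cannot freeze Int64; enable truncate_long_and_double");
      return ScalarType::Int; // Int64 -> Int32, with a truncation warning
    case ScalarType::Double:
      if (!truncate_long_and_double)
        throw std::runtime_error("cannot freeze Float64; enable truncate_long_and_double");
      return ScalarType::Float; // Float64 -> Float32, with a truncation warning
    default:
      return t; // supported types pass through unchanged
  }
}

int main() {
  std::cout << (freeze_storage_type(ScalarType::Long, true) == ScalarType::Int) << "\n"; // prints 1
}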
4 changes: 3 additions & 1 deletion core/conversion/converters/converter_util.h
@@ -5,7 +5,6 @@

#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/Weights.h"
#include "core/conversion/var/Var.h"
#include "core/util/prelude.h"

namespace trtorch {
@@ -45,6 +44,9 @@ nvinfer1::ILayer* add_elementwise(
// If an ITensor is of a type not dtype, add an Identity layer to cast it to dtype
nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);

// Freeze an at::Tensor into an IConstantLayer
nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t);

} // namespace converters
} // namespace conversion
} // namespace core
1 change: 1 addition & 0 deletions core/conversion/converters/converters.h
@@ -7,6 +7,7 @@
#include "torch/csrc/jit/runtime/custom_operator.h"

#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/converters/Weights.h"
#include "core/conversion/var/Var.h"
#include "core/util/prelude.h"
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/batch_norm.cpp
@@ -18,7 +18,7 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
auto input = args[0].ITensor(); // assumes non-static input Tensor
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto tensor_type = util::toATenDType(input->getType());
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
auto options = torch::TensorOptions().dtype(tensor_type);

torch::Tensor gamma, beta, mean, var;
5 changes: 1 addition & 4 deletions core/conversion/converters/impl/constant.cpp
@@ -18,10 +18,7 @@ auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
// Tensors vs. just Tensors

auto t = args[0].unwrapToTensor();
auto t_weights = Weights(ctx, t);
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
const_layer->setName(util::node_info(n).c_str());
auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], const_layer->getOutput(0));
auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], tensor_to_const(ctx, t));

LOG_DEBUG("Output tensor shape: " << const_out->getDimensions());

2 changes: 1 addition & 1 deletion core/conversion/converters/impl/select.cpp
@@ -257,7 +257,7 @@ auto select_registrations TRTORCH_UNUSED =
.pattern({"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
auto mask = args[1].ITensorOrFreeze(ctx);
mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
auto val = args[2].unwrapToScalar().to<float>();
auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));
4 changes: 0 additions & 4 deletions core/conversion/evaluators/aten.cpp
@@ -589,10 +589,6 @@ auto aten_registrations TRTORCH_UNUSED =
auto dtype = args.at(n->input(1)).IValue();
auto device = args.at(n->input(2)).IValue();
auto tensor = createTensorFromList(*data, *dtype, *device);
if (tensor.dtype() == at::kByte) {
return tensor.to(at::kFloat);
}
std::cout << tensor << std::endl;
return tensor;
},
EvalOptions().validSchemas(
2 changes: 1 addition & 1 deletion core/conversion/var/BUILD
@@ -19,7 +19,7 @@ cc_library(
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/conversion/converters:weights",
"//core/conversion/converters:converter_util",
"//core/conversion/tensorcontainer:tensorcontainer",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
26 changes: 2 additions & 24 deletions core/conversion/var/Var.cpp
@@ -1,5 +1,6 @@
#include <sstream>

#include "core/conversion/converters/converter_util.h"
#include "core/conversion/var/Var.h"
#include "core/util/prelude.h"

@@ -98,31 +99,8 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {

if (isIValue()) {
if (ptr_.ivalue->isTensor()) {
auto weights = converters::Weights();
auto tensor = ptr_.ivalue->toTensor();
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) &&
!ctx->settings.truncate_long_and_double) {
TRTORCH_THROW_ERROR(
"Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
} else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kInt));
LOG_WARNING("Truncating weight (constant in the graph) from Int64 to Int32");
} else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
LOG_WARNING("Truncating weight (constant in the graph) from Float64 to Float32");
} else {
weights = converters::Weights(ctx, tensor);
}

auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
out = const_layer->getOutput(0);

std::ostringstream tensor_id;
tensor_id << reinterpret_cast<int*>(out);

LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
out = converters::tensor_to_const(ctx, tensor);
} else {
// Split converter generates c10::IValue which hold TensorContainer.
auto output_container = ptr_.ivalue->toCustomClass<TensorContainer>();
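After this change, ITensorOrFreeze reduces to a dispatch: hand back the ITensor the Var already wraps, or freeze the wrapped tensor through the shared helper. A standalone sketch of that shape, with stub types standing in for Var, nvinfer1::ITensor, and at::Tensor (the real method also handles the TensorContainer case shown above):

#include <iostream>
#include <string>
#include <variant>

struct ITensorStub { std::string name; };  // stands in for nvinfer1::ITensor*
struct ATTensorStub { std::string name; }; // stands in for at::Tensor

// The single freeze path, standing in for converters::tensor_to_const.
ITensorStub tensor_to_const(const ATTensorStub& t) {
  return ITensorStub{"[Freeze Tensor " + t.name + " ]"};
}

struct VarStub {
  std::variant<ITensorStub, ATTensorStub> v;
  ITensorStub ITensorOrFreeze() {
    if (auto* live = std::get_if<ITensorStub>(&v))
      return *live;                                    // already an engine tensor
    return tensor_to_const(std::get<ATTensorStub>(v)); // freeze via the shared helper
  }
};

int main() {
  VarStub live{ITensorStub{"x"}}, weight{ATTensorStub{"w"}};
  std::cout << live.ITensorOrFreeze().name << "\n";   // x
  std::cout << weight.ITensorOrFreeze().name << "\n"; // [Freeze Tensor w ]
}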
1 change: 0 additions & 1 deletion core/conversion/var/Var.h
@@ -4,7 +4,6 @@
#include <string>

#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/Weights.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/ir.h"
4 changes: 2 additions & 2 deletions core/runtime/register_trt_op.cpp
@@ -110,7 +110,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
TRTORCH_CHECK(
inputs[pyt_idx].is_cuda(),
"Expected input tensors to have device cuda, found device " << inputs[pyt_idx].device());
auto expected_type = util::toATenDType(compiled_engine->exec_ctx->getEngine().getBindingDataType(i));
auto expected_type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getBindingDataType(i));
TRTORCH_CHECK(
inputs[pyt_idx].dtype() == expected_type,
"Expected input tensors to have type " << expected_type << ", found type " << inputs[pyt_idx].dtype());
@@ -131,7 +131,7 @@
auto out_shape = compiled_engine->exec_ctx->getBindingDimensions(o);
LOG_DEBUG("Output shape: " << out_shape);
auto dims = core::util::toVec(out_shape);
auto type = util::toATenDType(compiled_engine->exec_ctx->getEngine().getBindingDataType(o));
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getBindingDataType(o));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
gpu_handles.push_back(outputs[pyt_idx].data_ptr());
}
34 changes: 26 additions & 8 deletions core/util/trt_util.cpp
@@ -260,25 +260,43 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_aten_type_
return get_trt_at_type_map();
}

at::ScalarType toATenDType(nvinfer1::DataType t) {
at::ScalarType TRTDataTypeToScalarType(nvinfer1::DataType t) {
auto type = optTRTDataTypeToScalarType(t);
TRTORCH_CHECK(type, "Unsupported TensorRT data type " << t);
return type.value();
}

c10::optional<at::ScalarType> optTRTDataTypeToScalarType(nvinfer1::DataType t) {
auto trt_aten_type_map = get_trt_aten_type_map();
TRTORCH_CHECK(trt_aten_type_map.find(t) != trt_aten_type_map.end(), "Unsupported TensorRT datatype");
return trt_aten_type_map.at(t);
if (trt_aten_type_map.find(t) != trt_aten_type_map.end()) {
return trt_aten_type_map.at(t);
} else {
return {};
}
}

const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_aten_trt_type_map() {
return get_at_trt_type_map();
}

nvinfer1::DataType toTRTDataType(at::ScalarType t) {
nvinfer1::DataType ScalarTypeToTRTDataType(at::ScalarType t) {
auto type = optScalarTypeToTRTDataType(t);
TRTORCH_CHECK(type, "Unsupported ATen data type " << t);
return type.value();
}

c10::optional<nvinfer1::DataType> optScalarTypeToTRTDataType(at::ScalarType t) {
auto aten_trt_type_map = get_aten_trt_type_map();
TRTORCH_CHECK(aten_trt_type_map.find(t) != aten_trt_type_map.end(), "Unsupported Aten datatype");
return aten_trt_type_map.at(t);
if (aten_trt_type_map.find(t) != aten_trt_type_map.end()) {
return aten_trt_type_map.at(t);
} else {
return {};
}
}

c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype) {
c10::optional<nvinfer1::DataType> optTypeMetaToTRTDataType(caffe2::TypeMeta dtype) {
if (auto t = c10::optTypeMetaToScalarType(dtype)) {
return toTRTDataType(t.value());
return optScalarTypeToTRTDataType(t.value());
} else {
return {};
}
8 changes: 5 additions & 3 deletions core/util/trt_util.h
@@ -120,9 +120,11 @@ nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l);
std::vector<int64_t> toVec(nvinfer1::Dims d);
std::string toStr(nvinfer1::Dims d);

at::ScalarType toATenDType(nvinfer1::DataType t);
nvinfer1::DataType toTRTDataType(at::ScalarType t);
c10::optional<nvinfer1::DataType> toTRTDataType(caffe2::TypeMeta dtype);
at::ScalarType TRTDataTypeToScalarType(nvinfer1::DataType t);
c10::optional<at::ScalarType> optTRTDataTypeToScalarType(nvinfer1::DataType t);
nvinfer1::DataType ScalarTypeToTRTDataType(at::ScalarType t);
c10::optional<nvinfer1::DataType> optScalarTypeToTRTDataType(at::ScalarType t);
c10::optional<nvinfer1::DataType> optTypeMetaToTRTDataType(caffe2::TypeMeta dtype);
torch::jit::Value* getOrAddInputForValue(
torch::jit::Value* old_value,
std::shared_ptr<torch::jit::Graph>& graph,
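The renamed utilities in this header follow one pattern: an opt*-prefixed lookup that returns an empty optional for unsupported types, paired with a throwing wrapper for call sites that cannot proceed without an answer. A minimal self-contained sketch of the pattern, where std::optional and toy enums stand in for c10::optional and the real type sets (the function names match the declarations above):

#include <iostream>
#include <optional>
#include <stdexcept>
#include <unordered_map>

enum class ScalarType { Float, Int, ComplexDouble };
enum class TRTDataType { kFLOAT, kINT32 };

// Lookup that leaves the failure decision to the caller.
std::optional<TRTDataType> optScalarTypeToTRTDataType(ScalarType t) {
  static const std::unordered_map<ScalarType, TRTDataType> map = {
      {ScalarType::Float, TRTDataType::kFLOAT},
      {ScalarType::Int, TRTDataType::kINT32},
  };
  auto it = map.find(t);
  if (it == map.end()) {
    return {}; // unsupported: empty optional instead of an exception
  }
  return it->second;
}

// Throwing wrapper for call sites that require a valid type.
TRTDataType ScalarTypeToTRTDataType(ScalarType t) {
  auto type = optScalarTypeToTRTDataType(t);
  if (!type) {
    throw std::runtime_error("Unsupported ATen data type"); // stands in for TRTORCH_CHECK
  }
  return type.value();
}

int main() {
  std::cout << !optScalarTypeToTRTDataType(ScalarType::ComplexDouble).has_value() << "\n"; // 1
  std::cout << (ScalarTypeToTRTDataType(ScalarType::Float) == TRTDataType::kFLOAT) << "\n"; // 1
}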
