Merge pull request #1647 from pytorch/aten_size_fix
feat(//core/conversion): Add support for aten::size with dynamically shaped models in the TorchScript backend.
peri044 authored and bowang007 committed Apr 28, 2023
1 parent 013934a commit d5edd7c
Showing 13 changed files with 445 additions and 209 deletions.
2 changes: 1 addition & 1 deletion core/conversion/conversion.cpp
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
return {};
}
}
-  auto eval = evaluators::EvalNode(n, eval_args);
+  auto eval = evaluators::EvalNode(ctx, n, eval_args);
return eval;
}

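The extra ConversionCtx* argument threads the conversion context into evaluators so they can emit TensorRT layers rather than only fold constants. A minimal sketch of an evaluator written against the new signature (`example_size_eval` is hypothetical; `dynamic_size_layer` is the helper added in eval_util.cpp below):

```cpp
#include "core/conversion/evaluators/eval_util.h"
#include "core/conversion/evaluators/evaluators.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

// Hypothetical evaluator: folds to a plain IValue when shapes are static,
// but emits TensorRT layers through ctx when input shapes are dynamic.
c10::optional<torch::jit::IValue> example_size_eval(
    ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
  if (ctx->input_is_dynamic) {
    return dynamic_size_layer(ctx, n, args); // needs ctx->net, hence the new parameter
  }
  return torch::jit::IValue(args.at(n->input(0)).unwrapToTensor().size(0));
}

} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
```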
36 changes: 24 additions & 12 deletions core/conversion/converters/impl/shuffle.cpp
@@ -70,25 +70,37 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
auto in = args[0].ITensorOrFreeze(ctx);
auto in_shape = util::toVec(in->getDimensions());
std::vector<int64_t> new_shape;
+nvinfer1::ITensor* shape_tensor;
 if (ctx->input_is_dynamic) {
-  new_shape = util::toVec(args[1].unwrapToIntList().vec());
-  int nbDynamicDims = 0;
-  for (size_t i = 0; i < new_shape.size(); i++) {
-    if (in_shape[i] == -1)
-      nbDynamicDims++;
-  }
-  if (nbDynamicDims > 1) {
-    TORCHTRT_THROW_ERROR(
-        "Resize is currently not supported when target shape contains more than one dynamic dimension");
-  }
+  LOG_DEBUG("Using dynamic version of reshape layer");
+  if (args[1].isITensorList()) {
+    LOG_DEBUG("Shape tensor is an ITensorList");
+    auto new_shape = args[1].unwrapToITensorList();
+    auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
+    TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+    concat_layer->setAxis(static_cast<int32_t>(0));
+    shape_tensor = concat_layer->getOutput(0);
+  } else if (args[1].isIntList()) {
+    LOG_DEBUG("Shape tensor is an IntList");
+    auto shape_vec = args[1].unwrapToIntList().vec();
+    shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
+  } else {
+    LOG_ERROR(
+        "Invalid IValue type of " << args[1].IValue()->type()
+                                  << " detected for shape tensor from node: " << *n);
+  }
} else {
new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
}

 auto shuffle = ctx->net->addShuffle(*in);
-TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-shuffle->setReshapeDimensions(util::toDims(new_shape));
 shuffle->setName(util::node_info(n).c_str());
+TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);

+if (ctx->input_is_dynamic) {
+  shuffle->setInput(1, *shape_tensor);
+} else {
+  shuffle->setReshapeDimensions(util::toDims(new_shape));
+}

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
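For context on the mechanism the new branch uses: with dynamic shapes, TensorRT's IShuffleLayer takes the target shape as a second layer input instead of via setReshapeDimensions(), which only accepts build-time dimensions. A standalone sketch against the plain TensorRT API (the function name is illustrative, not part of this PR):

```cpp
#include "NvInfer.h"

// Reshape `input` to a shape that is only known at runtime. The 1-D INT32
// shape tensor is wired in with setInput(1, ...); TensorRT resolves the
// actual dimensions at execution time.
nvinfer1::ITensor* reshape_dynamic(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* shape /* 1-D INT32 shape tensor */) {
  auto* shuffle = network->addShuffle(*input);
  shuffle->setInput(1, *shape); // dynamic counterpart of setReshapeDimensions()
  return shuffle->getOutput(0);
}
```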
4 changes: 2 additions & 2 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
return get_evaluator_registry().GetRegisteredEvaluatorList();
}

-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
auto evaluator = get_evaluator_registry().GetEvaluator(n);
-  return evaluator(n, args);
+  return evaluator(ctx, n, args);
}

void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
101 changes: 57 additions & 44 deletions core/conversion/evaluators/aten.cpp

Large diffs are not rendered by default.

252 changes: 126 additions & 126 deletions core/conversion/evaluators/eval_macros.h

Large diffs are not rendered by default.

45 changes: 44 additions & 1 deletion core/conversion/evaluators/eval_util.cpp
@@ -1,3 +1,4 @@
#include "core/conversion/evaluators/eval_util.h"
#include <ATen/ATen.h>
#include "ATen/InitialTensorOptions.h"
#include "ATen/core/List.h"
@@ -6,12 +7,54 @@
#include "ATen/core/jit_type.h"
#include "c10/util/irange.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

nvinfer1::ITensor* index_layer(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* input_tensor,
int64_t index) {
// index to access needs to be an at::Tensor
at::Tensor indices = torch::tensor({index}).to(torch::kI32);
auto indices_out = converters::tensor_to_const(ctx, indices);

auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto indexed_tensor = gather_layer->getOutput(0);
return indexed_tensor;
}

c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
LOG_DEBUG("Using dynamic version of aten::size evaluator");
auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
LOG_DEBUG("Input dimensions: " << in->getDimensions());
auto shape_layer = ctx->net->addShape(*in);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
auto shape_1d_tensor = shape_layer->getOutput(0);

if (n->inputs().size() != 1) {
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
auto dim = args.at(n->input(1)).unwrapToInt();
// Handle negative axis by referring to nbDims of input Tensor
dim = dim < 0 ? dim + maxDim : dim;
LOG_DEBUG("Dimension to select: " << dim);
shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
}

LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());

auto tensor_holder = TensorContainer();
tensor_holder.hold_tensor(shape_1d_tensor);
auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));

return shape_1d_ivalue;
}

int64_t normalizeIndex(int64_t idx, int64_t list_size) {
if (idx < 0) {
// Handle negative indexing
@@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
}

// TODO: Conditionally enable truncation based on user setting
-at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
// This function is basically same with the one in
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float
// won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion
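index_layer and dynamic_size_layer together implement aten::size for dynamic shapes: materialize the runtime shape with IShapeLayer, then pick one entry with IGatherLayer. The same pattern in plain TensorRT, for reference (`runtime_dim` is illustrative):

```cpp
#include "NvInfer.h"

// Return a 1-element tensor holding input.size(dim), computed at runtime.
// `dim_index` is a 1-element INT32 constant (e.g. built with addConstant);
// its backing memory must stay alive until the engine is built.
nvinfer1::ITensor* runtime_dim(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* input,
    nvinfer1::ITensor* dim_index) {
  auto* shape = network->addShape(*input)->getOutput(0); // 1-D shape tensor
  auto* gather = network->addGather(*shape, *dim_index, /*axis=*/0);
  return gather->getOutput(0);
}
```

The IValue returned by dynamic_size_layer wraps the resulting ITensor in a TensorContainer so that downstream evaluators such as prim::ListConstruct can consume it.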
9 changes: 9 additions & 0 deletions core/conversion/evaluators/eval_util.h
@@ -1,12 +1,21 @@
#pragma once

#include "core/conversion/evaluators/evaluators.h"
#include "torch/csrc/jit/ir/ir.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

nvinfer1::ITensor* index_layer(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* input_tensor,
int64_t index);

c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);

c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
at::Tensor createTensorFromList(
const torch::jit::IValue& data,
7 changes: 5 additions & 2 deletions core/conversion/evaluators/evaluators.h
@@ -6,6 +6,8 @@

#include "torch/csrc/jit/ir/ir.h"

#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/conversion/var/Var.h"

@@ -33,7 +35,8 @@ inline bool constTypesOnly(kwargs& args) {
// to use the node itself to pull out arguments.
// This means that you should iterate over node inputs vs. the args
// when writing evaluators
-typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
+typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)>
+    NodeEvaluator;

struct EvalOptions {
std::set<c10::TypePtr> blacklisted_output_types;
@@ -72,7 +75,7 @@ struct EvalRegistration {
: kind(_kind), evaluator(_evaluator), options(_options){};
};

-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
std::vector<std::string> getEvaluatorList();
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
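Because NodeEvaluator is a single std::function type, every registered evaluator adopts the widened signature at once. A registration sketch in the new style (`custom::identity` is a made-up symbol for illustration):

```cpp
#include "core/conversion/evaluators/evaluators.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace evaluators {

auto example_registrations = RegisterNodeEvaluators().evaluator(
    {c10::Symbol::fromQualString("custom::identity"),
     [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args)
         -> c10::optional<torch::jit::IValue> {
       (void)ctx; // unused here, but required by the NodeEvaluator signature
       return *(args.at(n->input(0)).IValue());
     }});

} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
```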
59 changes: 38 additions & 21 deletions core/conversion/evaluators/prim.cpp
@@ -1,12 +1,11 @@
#include <limits>

#include "torch/csrc/jit/ir/ir.h"
//#include "torch/csrc/jit/ir/constants.h"
#include "ATen/core/List.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/stack.h"
#include "c10/util/intrusive_ptr.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/torch.h"

#include "core/conversion/evaluators/eval_macros.h"
@@ -24,28 +23,28 @@ auto prim_registrations =
RegisterNodeEvaluators()
.evaluator(
{torch::jit::prim::Constant,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->output()->type()->kind() == at::FunctionType::Kind) {
return {};
}
return evaluators::toIValue(n->output());
}})
.evaluator(
{torch::jit::prim::NumToTensor,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
}})
.evaluator(
{torch::jit::prim::ListUnpack,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
const torch::jit::IValue* outputs = args.at(n->input()).IValue();
auto outputVec = outputs->toList().vec();
return std::move(c10::ivalue::Tuple::create(outputVec));
}})
.evaluator(
{torch::jit::prim::ListConstruct,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
const auto num_inputs = n->inputs().size();
if (constTypesOnly(args)) {
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -89,9 +88,8 @@ auto prim_registrations =
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
}
} else {
-c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
-c10::TypePtr elementType = lt->getElementType();
-auto list = c10::impl::GenericList(elementType);
+// List would be of IValues (with ITensors embedded in them)
+auto list = c10::impl::GenericList(c10::AnyType::get());
list.reserve(num_inputs);
for (auto in : n->inputs()) {
if (args.at(in).isITensor()) {
@@ -103,8 +101,27 @@
if (args.at(in).IValue()->isNone()) {
auto ival = torch::jit::IValue();
list.emplace_back(std::move(ival));
+} else if (args.at(in).IValue()->isInt()) {
+  auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+      ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32));
+  auto tensor_holder = TensorContainer();
+  tensor_holder.hold_tensor(itensor);
+  auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+  list.emplace_back(std::move(ival));
+} else if (args.at(in).IValue()->isDouble()) {
+  auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+      ctx, torch::tensor({args.at(in).unwrapToDouble()}).to(torch::kFloat));
+  auto tensor_holder = TensorContainer();
+  tensor_holder.hold_tensor(itensor);
+  auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+  list.emplace_back(std::move(ival));
 } else {
-  list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+  auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+      ctx, std::move(args.at(in).unwrapToTensor()));
+  auto tensor_holder = TensorContainer();
+  tensor_holder.hold_tensor(itensor);
+  auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+  list.emplace_back(std::move(ival));
}
}
}
@@ -113,7 +130,7 @@
}})
.evaluator(
{c10::Symbol::fromQualString("prim::dtype"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto input = args.at(n->input(0));
if (input.isITensor()) {
auto trt_dtype = input.ITensor()->getType();
Expand All @@ -136,7 +153,7 @@ auto prim_registrations =
})})
.evaluator(
{c10::Symbol::fromQualString("prim::min"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->inputs().size() == 1) {
auto a = args.at(n->input(0)).unwrapToIntList();
int64_t min = std::numeric_limits<int64_t>::max();
@@ -198,7 +215,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::max"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->inputs().size() == 1) {
auto a = args.at(n->input(0)).unwrapToIntList();
int64_t max = std::numeric_limits<int64_t>::min();
@@ -260,7 +277,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::shape"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
auto tensor_var = args.at(n->input(0));
if (tensor_var.isITensor()) {
@@ -274,7 +291,7 @@
EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
.evaluator(
{torch::jit::prim::TupleConstruct,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::IValue tuple = c10::ivalue::Tuple::create();
std::vector<c10::IValue> elems;
for (auto in : n->inputs()) {
@@ -292,7 +309,7 @@
}})
.evaluator(
{torch::jit::prim::TupleIndex,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
auto tuple = args.at(n->input(0)).IValue()->toTuple();
int64_t idx = args.at(n->input(1)).IValue()->toInt();
@@ -302,24 +319,24 @@
EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})})
.evaluator(
{torch::jit::prim::TupleUnpack,
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
auto output = args.at(n->input()).IValue()->toTuple();
return c10::optional<torch::jit::IValue>(std::move(output));
}})
.evaluator(
{c10::Symbol::fromQualString("prim::unchecked_cast"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return *(args.at(n->input(0)).IValue());
}})
.evaluator(
{c10::Symbol::fromQualString("prim::Uninitialized"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return c10::IValue::uninitialized();
}})
.evaluator(
{c10::Symbol::fromQualString("prim::RaiseException"),
-[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto exception = args.at(n->input(0)).IValue();
TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception);
return {};
@@ -328,4 +345,4 @@
} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
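The hold-and-wrap pattern repeated above is what lets a prim::ListConstruct result mix runtime ITensors with promoted scalars. Factored out as a sketch (`wrap_int_as_itensor` is illustrative, not a helper this PR adds):

```cpp
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "torch/torch.h"

namespace cc = torch_tensorrt::core::conversion;

// Promote a scalar to a 1-element INT32 TensorRT constant and wrap it in the
// TensorContainer IValue form that evaluators pass around.
c10::IValue wrap_int_as_itensor(cc::ConversionCtx* ctx, int64_t v) {
  auto* itensor = cc::converters::tensor_to_const(ctx, torch::tensor({v}).to(torch::kI32));
  auto holder = cc::TensorContainer();
  holder.hold_tensor(itensor);
  return c10::IValue(c10::make_intrusive<cc::TensorContainer>(holder));
}
```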
15 changes: 15 additions & 0 deletions core/conversion/var/Var.h
@@ -43,6 +43,7 @@ class Var : torch::CustomClassHolder {
c10::Scalar unwrapToScalar();
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);
c10::List<int64_t> unwrapToIntList();
std::vector<nvinfer1::ITensor*> unwrapToITensorList();
c10::List<double> unwrapToDoubleList(c10::List<double> default_val);
c10::List<double> unwrapToDoubleList();
c10::List<bool> unwrapToBoolList(c10::List<bool> default_val);
@@ -59,6 +60,20 @@
bool isIValue() const;
bool isITensor() const;
bool isNone() const;

bool isInt();
bool isDouble();
bool isBool();
bool isString();
bool isScalar();
bool isTensor();
bool isIntList();
bool isDoubleList();
bool isBoolList();
bool isTensorList();
bool isITensorList();
bool isList();

Var::Type type() const;
std::string type_name() const;

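The new type queries let converters dispatch on what a Var actually holds before unwrapping, exactly as the shuffle converter above does for aten::reshape's shape argument. A sketch (assumes the torch_tensorrt::core::conversion namespace; `shape_from_var` is illustrative):

```cpp
// Turn a shape argument into a single 1-D shape ITensor, whichever form the
// Var arrived in. Returns nullptr for unsupported IValue types.
nvinfer1::ITensor* shape_from_var(ConversionCtx* ctx, Var& arg, const torch::jit::Node* n) {
  if (arg.isITensorList()) {
    // Runtime shape components: concatenate into one shape tensor.
    auto parts = arg.unwrapToITensorList();
    auto concat = ctx->net->addConcatenation(parts.data(), parts.size());
    TORCHTRT_CHECK(concat, "Unable to create concatenation layer from node: " << *n);
    concat->setAxis(0);
    return concat->getOutput(0);
  } else if (arg.isIntList()) {
    // Fully static shape: bake it into an INT32 constant.
    auto dims = arg.unwrapToIntList().vec();
    return converters::tensor_to_const(ctx, torch::tensor(dims).to(torch::kI32));
  }
  return nullptr;
}
```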
