diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index b0e8174500..f8a26e8d77 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
       return {};
     }
   }
-  auto eval = evaluators::EvalNode(n, eval_args);
+  auto eval = evaluators::EvalNode(ctx, n, eval_args);
   return eval;
 }
 
diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index 2df7e653ef..f758c0cc47 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -70,25 +70,37 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
            auto in = args[0].ITensorOrFreeze(ctx);
            auto in_shape = util::toVec(in->getDimensions());
            std::vector<int64_t> new_shape;
+           nvinfer1::ITensor* shape_tensor;
            if (ctx->input_is_dynamic) {
-             new_shape = util::toVec(args[1].unwrapToIntList().vec());
-             int nbDynamicDims = 0;
-             for (size_t i = 0; i < new_shape.size(); i++) {
-               if (in_shape[i] == -1)
-                 nbDynamicDims++;
-             }
-             if (nbDynamicDims > 1) {
-               TORCHTRT_THROW_ERROR(
-                   "Resize is currently not supported when target shape contains more than one dynamic dimension");
+             LOG_DEBUG("Using dynamic version of reshape layer");
+             if (args[1].isITensorList()) {
+               LOG_DEBUG("Shape tensor is an ITensorList");
+               auto new_shape = args[1].unwrapToITensorList();
+               auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
+               TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+               concat_layer->setAxis(static_cast<int32_t>(0));
+               shape_tensor = concat_layer->getOutput(0);
+             } else if (args[1].isIntList()) {
+               LOG_DEBUG("Shape tensor is an IntList");
+               auto shape_vec = args[1].unwrapToIntList().vec();
+               shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
+             } else {
+               LOG_ERROR(
+                   "Invalid IValue type of " << args[1].IValue()->type()
+                                             << " detected for shape tensor from node: " << *n);
              }
            } else {
              new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
            }
            auto shuffle = ctx->net->addShuffle(*in);
-           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
-           shuffle->setReshapeDimensions(util::toDims(new_shape));
            shuffle->setName(util::node_info(n).c_str());
+           TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+
+           if (ctx->input_is_dynamic) {
+             shuffle->setInput(1, *shape_tensor);
+           } else {
+             shuffle->setReshapeDimensions(util::toDims(new_shape));
+           }
 
            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
diff --git a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
index 053e08a84e..36a2ff80bf 100644
--- a/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
+++ b/core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
   return get_evaluator_registry().GetRegisteredEvaluatorList();
 }
 
-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   auto evaluator = get_evaluator_registry().GetEvaluator(n);
-  return evaluator(n, args);
+  return evaluator(ctx, n, args);
 }
 
 void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
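The shuffle.cpp change above is the heart of the dynamic-reshape support: instead of baking static dimensions into the engine via setReshapeDimensions(), the target shape is fed to the shuffle layer as a runtime tensor. Stripped of the Torch-TensorRT plumbing, the TensorRT idiom looks roughly like this (a sketch; reshape_with_runtime_shape is an illustrative helper name, not part of the patch):

    // IShuffleLayer accepts the target shape as a second tensor input when the
    // shape is only known at execution time; input index 1 is the reshape spec.
    nvinfer1::ITensor* reshape_with_runtime_shape(
        nvinfer1::INetworkDefinition* net,
        nvinfer1::ITensor* in,
        nvinfer1::ITensor* shape) { // 1-D Int32 tensor holding the target shape
      auto* shuffle = net->addShuffle(*in);
      shuffle->setInput(1, *shape);
      return shuffle->getOutput(0);
    }

When the shape argument arrives as a list of per-dimension ITensors (the isITensorList() branch), they are first concatenated along axis 0 into that single 1-D shape tensor.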
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index 1f16f0f575..d78bf9878c 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -130,7 +130,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             {c10::Symbol::fromQualString("aten::zeros"),
              // aten::zeros(int[] size, *, int? dtype=None, int? layout=None,
              // Device? device=None, bool? pin_memory=None) -> (Tensor)
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
 
               // Input 1 here is the dtype
@@ -145,7 +145,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             {c10::Symbol::fromQualString("aten::ones"),
              // aten::ones(int[] size, *, int? dtype=None, int? layout=None,
              // Device? device=None, bool? pin_memory=None) -> (Tensor)
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
 
               // Input 1 here is the dtype
@@ -160,7 +160,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             {c10::Symbol::fromQualString("aten::full"),
              // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
              // Device? device=None, bool? pin_memory=None) -> (Tensor)
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
 
               // Input 2 here is the dtype
@@ -176,7 +176,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             {c10::Symbol::fromQualString("aten::full_like"),
              // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None,
              // Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor)
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Override options related to layout and device for TensorRT
               auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
               auto input_tensor_var = args.at(n->input(0));
memory_format=None) -> (Tensor))SIG"})}) .evaluator( {c10::Symbol::fromQualString("aten::slice"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { c10::List list = args.at(n->input(0)).IValue()->to>(); int64_t start = 0; @@ -257,19 +257,21 @@ auto aten_registrations TORCHTRT_UNUSED = {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})}) .evaluator( {c10::Symbol::fromQualString("aten::len"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { c10::List list = args.at(n->input(0)).IValue()->to>(); return static_cast(list.size()); }, EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})}) .evaluator( {c10::Symbol::fromQualString("aten::size"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size"); + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto tensor_var = args.at(n->input(0)); if (n->inputs().size() == 1) { if (tensor_var.isITensor()) { auto tensor = tensor_var.ITensor(); + if (ctx->input_is_dynamic) { + return dynamic_size_layer(ctx, n, args); + } return util::toVec(tensor->getDimensions()); } else if (tensor_var.IValue()->isTensor()) { auto tensor = tensor_var.unwrapToTensor(); @@ -283,6 +285,9 @@ auto aten_registrations TORCHTRT_UNUSED = } else { auto dim = args.at(n->input(1)).unwrapToInt(); if (tensor_var.isITensor()) { + if (ctx->input_is_dynamic) { + return dynamic_size_layer(ctx, n, args); + } auto tensor = tensor_var.ITensor(); auto dims = util::toVec(tensor->getDimensions()); auto nbDims = tensor->getDimensions().nbDims; @@ -314,22 +319,30 @@ auto aten_registrations TORCHTRT_UNUSED = {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) .evaluator( {c10::Symbol::fromQualString("aten::__getitem__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto list = args.at(n->input(0)).IValue()->to>(); + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { + auto list_input = args.at(n->input(0)); auto idx = args.at(n->input(1)).unwrapToInt(); - - const int64_t list_size = list.size(); - const int64_t normalized_idx = normalizeIndex(idx, list_size); - TORCHTRT_CHECK( - normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); - return list.get(normalized_idx); + if (list_input.isIValue()) { + auto list = args.at(n->input(0)).IValue()->to>(); + const int64_t list_size = list.size(); + const int64_t normalized_idx = normalizeIndex(idx, list_size); + TORCHTRT_CHECK( + normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)"); + return list.get(normalized_idx); + } else if (list_input.isITensor()) { + auto indexed_tensor = index_layer(ctx, n, list_input.ITensorOrFreeze(ctx), idx); + auto tensor_holder = TensorContainer(); + tensor_holder.hold_tensor(indexed_tensor); + auto indexed_ivalue = c10::IValue(std::move(c10::make_intrusive(tensor_holder))); + return indexed_ivalue; + } }, EvalOptions().validSchemas({ "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))", })}) .evaluator( {c10::Symbol::fromQualString("aten::append"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> 
@@ -349,7 +362,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::extend"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isList() && args.at(n->input(1)).IValue()->isList()) {
                 c10::IValue* self_ptr = args.at(n->input(0)).IValueMut();
                 auto self = self_ptr->to<c10::List<c10::IValue>>();
@@ -375,7 +388,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::neg"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto el = args.at(n->input(0)).unwrapToInt();
 
               return el * -1;
@@ -385,7 +398,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::add"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isInt()) {
                 auto a = args.at(n->input(0)).unwrapToInt();
                 auto b = args.at(n->input(1)).unwrapToInt();
@@ -411,7 +424,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                 "aten::add.str(str a, str b) -> (str)"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::add_"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isList()) {
                 auto a = args.at(n->input(0)).IValue()->toListRef();
                 auto b = args.at(n->input(1)).IValue()->toListRef();
@@ -441,7 +454,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             EvalOptions().validSchemas({"aten::add_.t(t[](a!) self, t[] b) -> (t[])"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::mul"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isInt()) {
                 auto a = args.at(n->input(0)).unwrapToInt();
                 auto b = args.at(n->input(1)).unwrapToInt();
self, t[] b) -> (t[])"})}) .evaluator( {c10::Symbol::fromQualString("aten::mul"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -461,7 +474,7 @@ auto aten_registrations TORCHTRT_UNUSED = {"aten::mul.int(int a, int b) -> (int)", "aten::mul.float(float a, float b) -> (float)"})}) .evaluator( {c10::Symbol::fromQualString("aten::sub"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); auto b = args.at(n->input(1)).unwrapToInt(); @@ -483,7 +496,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::Bool"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return (bool)a; @@ -500,7 +513,7 @@ auto aten_registrations TORCHTRT_UNUSED = EvalOptions().validSchemas({"aten::Bool.int(int a) -> (bool)", "aten::Bool.float(float b) -> (bool)"})}) .evaluator( {c10::Symbol::fromQualString("aten::Float"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return (float)a; @@ -524,7 +537,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::Int"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { if (args.at(n->input(0)).IValue()->isInt()) { auto a = args.at(n->input(0)).unwrapToInt(); return (int)a; @@ -549,7 +562,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::__not__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto el = args.at(n->input(0)).unwrapToBool(); return !el; @@ -559,7 +572,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::__is__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto self = args.at(n->input(0)).IValue(); auto obj = args.at(n->input(1)).IValue(); @@ -570,7 +583,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::__isnot__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto self = args.at(n->input(0)).IValue(); auto obj = args.at(n->input(1)).IValue(); @@ -581,7 +594,7 @@ auto aten_registrations TORCHTRT_UNUSED = })}) .evaluator( {c10::Symbol::fromQualString("aten::numel"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { LOG_WARNING("There may be undefined behavior using dynamic shape and aten::numel"); auto tensor_var = 
@@ -597,7 +610,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::dim"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto tensor_var = args.at(n->input(0));
               if (tensor_var.isITensor()) {
                 auto tensor = tensor_var.ITensor();
@@ -612,7 +625,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::div"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isInt()) {
                 auto a = args.at(n->input(0)).unwrapToInt();
                 auto b = args.at(n->input(1)).unwrapToInt();
@@ -634,7 +647,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::floordiv"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isInt()) {
                 auto a = args.at(n->input(0)).unwrapToInt();
                 auto b = args.at(n->input(1)).unwrapToInt();
@@ -656,7 +669,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::floor"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isInt()) {
                 auto el = args.at(n->input(0)).unwrapToInt();
                 return static_cast<int64_t>(std::floor(el));
@@ -676,7 +689,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::sqrt"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).IValue()->isInt()) {
                 auto a = args.at(n->input(0)).unwrapToInt();
                 return std::sqrt(static_cast<double>(a));
@@ -696,7 +709,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::warn"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto warning = args.at(n->input(0)).IValue();
               LOG_WARNING("Warning from TorchScript: " << *warning);
               return {};
@@ -704,7 +717,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             EvalOptions()})
         .evaluator(
             {c10::Symbol::fromQualString("aten::is_floating_point"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto tensor_var = args.at(n->input(0));
               if (tensor_var.isITensor()) {
                 auto tensor = tensor_var.ITensor();
@@ -721,7 +734,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::tensor"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto data = args.at(n->input(0)).IValue();
               auto dtype = args.at(n->input(1)).IValue();
               auto device = args.at(n->input(2)).IValue();
@@ -732,7 +745,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                 {"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::arange"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto schema = n->maybeSchema();
               TORCHTRT_CHECK(schema, "Unable to get schema for node: " << *n);
               auto name = schema->operator_name();
@@ -783,7 +796,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::clone"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(0)).isITensor()) {
                 auto source_tensor = args.at(n->input(0)).ITensor();
                 auto tensor_holder = TensorContainer();
@@ -801,7 +814,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::copy_"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (args.at(n->input(1)).isITensor()) {
                 auto source_tensor = args.at(n->input(1)).ITensor();
                 auto tensor_holder = TensorContainer();
@@ -820,7 +833,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("aten::format"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               int64_t input_num = n->inputs().size();
               std::vector<torch::jit::IValue> stack;
               for (auto v : n->inputs()) {
@@ -837,7 +850,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::__range_length"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto lo = args.at(n->input(0)).unwrapToInt();
               auto hi = args.at(n->input(1)).unwrapToInt();
               auto step = args.at(n->input(2)).unwrapToInt();
@@ -856,7 +869,7 @@ auto aten_registrations TORCHTRT_UNUSED =
             EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})})
         .evaluator(
             {c10::Symbol::fromQualString("aten::__derive_index"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto idx = args.at(n->input(0)).unwrapToInt();
               auto start = args.at(n->input(1)).unwrapToInt();
               auto step = args.at(n->input(2)).unwrapToInt();
diff --git a/core/conversion/evaluators/eval_macros.h b/core/conversion/evaluators/eval_macros.h
index 2bb126c1e9..5a0328663b 100644
--- a/core/conversion/evaluators/eval_macros.h
+++ b/core/conversion/evaluators/eval_macros.h
@@ -2,134 +2,134 @@
 
 #include "core/conversion/evaluators/evaluators.h"
 
-#define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
-  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
-      {c10::Symbol::fromQualString(node_kind), \
-       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
-         if (args.at(n->input(0)).IValue()->isInt()) { \
-           auto a = args.at(n->input(0)).unwrapToInt(); \
-           if (args.at(n->input(1)).IValue()->isInt()) { \
-             auto b = args.at(n->input(1)).unwrapToInt(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
-             auto b = args.at(n->input(1)).unwrapToDouble(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isBool()) { \
-             auto b = args.at(n->input(1)).unwrapToBool(); \
-             return operation; \
-           } else { \
-             TORCHTRT_THROW_ERROR( \
-                 "Unimplemented data type for " \
-                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
-             return {}; \
-           } \
-         } else if (args.at(n->input(0)).IValue()->isDouble()) { \
-           auto a = args.at(n->input(0)).unwrapToDouble(); \
-           if (args.at(n->input(1)).IValue()->isInt()) { \
-             auto b = args.at(n->input(1)).unwrapToInt(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
-             auto b = args.at(n->input(1)).unwrapToDouble(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isBool()) { \
-             auto b = args.at(n->input(1)).unwrapToBool(); \
-             return operation; \
-           } else { \
-             TORCHTRT_THROW_ERROR( \
-                 "Unimplemented data type for " \
-                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
-             return {}; \
-           } \
-         } else if (args.at(n->input(0)).IValue()->isBool()) { \
-           auto a = args.at(n->input(0)).unwrapToBool(); \
-           if (args.at(n->input(1)).IValue()->isInt()) { \
-             auto b = args.at(n->input(1)).unwrapToInt(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
-             auto b = args.at(n->input(1)).unwrapToDouble(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isBool()) { \
-             auto b = args.at(n->input(1)).unwrapToBool(); \
-             return operation; \
-           } else { \
-             TORCHTRT_THROW_ERROR( \
-                 "Unimplemented data type for " \
-                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
-             return {}; \
-           } \
-         } else if (args.at(n->input(0)).IValue()->isString()) { \
-           auto a = args.at(n->input(0)).unwrapToString(); \
-           if (args.at(n->input(1)).IValue()->isString()) { \
-             auto b = args.at(n->input(1)).unwrapToString(); \
-             return operation; \
-           } else { \
-             TORCHTRT_THROW_ERROR( \
-                 "Unimplemented data type for " \
-                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
-             return {}; \
-           } \
-         } else { \
-           TORCHTRT_THROW_ERROR( \
-               "Unimplemented data type for " \
-               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
-           return {}; \
-         } \
-       }, \
+#define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
+  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
+      {c10::Symbol::fromQualString(node_kind), \
+       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+         if (args.at(n->input(0)).IValue()->isInt()) { \
+           auto a = args.at(n->input(0)).unwrapToInt(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else if (args.at(n->input(0)).IValue()->isDouble()) { \
+           auto a = args.at(n->input(0)).unwrapToDouble(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else if (args.at(n->input(0)).IValue()->isBool()) { \
+           auto a = args.at(n->input(0)).unwrapToBool(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else if (args.at(n->input(0)).IValue()->isString()) { \
+           auto a = args.at(n->input(0)).unwrapToString(); \
+           if (args.at(n->input(1)).IValue()->isString()) { \
+             auto b = args.at(n->input(1)).unwrapToString(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else { \
+           TORCHTRT_THROW_ERROR( \
+               "Unimplemented data type for " \
+               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
+           return {}; \
+         } \
+       }, \
       EvalOptions().validSchemas(schemas)});
 
-#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
-  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
-      {c10::Symbol::fromQualString(node_kind), \
-       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
-         if (args.at(n->input(0)).IValue()->isInt()) { \
-           auto a = args.at(n->input(0)).unwrapToInt(); \
-           if (args.at(n->input(1)).IValue()->isInt()) { \
-             auto b = args.at(n->input(1)).unwrapToInt(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
-             auto b = args.at(n->input(1)).unwrapToDouble(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isBool()) { \
-             auto b = args.at(n->input(1)).unwrapToBool(); \
-             return operation; \
-           } else { \
-             TORCHTRT_THROW_ERROR( \
-                 "Unimplemented data type for " \
-                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
-             return {}; \
-           } \
-         } else if (args.at(n->input(0)).IValue()->isDouble()) { \
-           auto a = args.at(n->input(0)).unwrapToDouble(); \
-           if (args.at(n->input(1)).IValue()->isInt()) { \
-             auto b = args.at(n->input(1)).unwrapToInt(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
-             auto b = args.at(n->input(1)).unwrapToDouble(); \
-             return operation; \
-           } else if (args.at(n->input(1)).IValue()->isBool()) { \
-             auto b = args.at(n->input(1)).unwrapToBool(); \
-             return operation; \
-           } else { \
-             TORCHTRT_THROW_ERROR( \
-                 "Unimplemented data type for " \
-                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
-             return {}; \
-           } \
-         } else { \
-           TORCHTRT_THROW_ERROR( \
-               "Unimplemented data type for " \
-               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
-           return {}; \
-         } \
-       }, \
+#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
+  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
+      {c10::Symbol::fromQualString(node_kind), \
+       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+         if (args.at(n->input(0)).IValue()->isInt()) { \
+           auto a = args.at(n->input(0)).unwrapToInt(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else if (args.at(n->input(0)).IValue()->isDouble()) { \
+           auto a = args.at(n->input(0)).unwrapToDouble(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else { \
+           TORCHTRT_THROW_ERROR( \
+               "Unimplemented data type for " \
+               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
+           return {}; \
+         } \
+       }, \
       EvalOptions().validSchemas(schemas)});
 
-#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
-  auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
-      {c10::Symbol::fromQualString(node_name), \
-       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
-         auto a = args.at(n->input(0)).unwrapTo<type>(); \
-         auto b = args.at(n->input(1)).unwrapTo<type>(); \
-         return operation; \
-       }, \
+#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
+  auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
+      {c10::Symbol::fromQualString(node_name), \
+       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+         auto a = args.at(n->input(0)).unwrapTo<type>(); \
+         auto b = args.at(n->input(1)).unwrapTo<type>(); \
+         return operation; \
+       }, \
       EvalOptions().validSchemas(schemas)});
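These macros are instantiated elsewhere in the evaluator library (for comparison and arithmetic operators), so after this change every instantiation expands to a ctx-aware lambda even when ctx goes unused. An illustrative instantiation of the simple form, assuming the schema list is passed as a std::set<std::string> as elsewhere in the codebase:

    // Expands to a registration whose lambda now takes (ConversionCtx*, Node*, kwargs&);
    // `a == b` is spliced in as `operation` and `bool` as `type`. (Illustrative only.)
    DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
        eq,
        "aten::eq",
        a == b,
        bool,
        std::set<std::string>({"aten::eq.bool(bool a, bool b) -> (bool)"}));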
diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp
index c14f9a6714..fcd8f0c910 100644
--- a/core/conversion/evaluators/eval_util.cpp
+++ b/core/conversion/evaluators/eval_util.cpp
@@ -1,3 +1,4 @@
+#include "core/conversion/evaluators/eval_util.h"
 #include <ATen/ATen.h>
 #include "ATen/InitialTensorOptions.h"
 #include "ATen/core/List.h"
@@ -6,12 +7,54 @@
 #include "ATen/core/jit_type.h"
 #include "c10/util/irange.h"
 #include "core/util/prelude.h"
+#include "torch/torch.h"
 
 namespace torch_tensorrt {
 namespace core {
 namespace conversion {
 namespace evaluators {
 
+nvinfer1::ITensor* index_layer(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* input_tensor,
+    int64_t index) {
+  // index to access needs to be an at::Tensor
+  at::Tensor indices = torch::tensor({index}).to(torch::kI32);
+  auto indices_out = converters::tensor_to_const(ctx, indices);
+
+  auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
+  TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+  auto indexed_tensor = gather_layer->getOutput(0);
+  return indexed_tensor;
+}
+
+c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
+  LOG_DEBUG("Using dynamic version of aten::size evaluator");
+  auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
+  LOG_DEBUG("Input dimensions: " << in->getDimensions());
+  auto shape_layer = ctx->net->addShape(*in);
+  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+  auto shape_1d_tensor = shape_layer->getOutput(0);
+
+  if (n->inputs().size() != 1) {
+    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
+    auto dim = args.at(n->input(1)).unwrapToInt();
+    // Handle negative axis by referring to nbDims of input Tensor
+    dim = dim < 0 ? dim + maxDim : dim;
+    LOG_DEBUG("Dimension to select: " << dim);
+    shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
+  }
+
+  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
+
+  auto tensor_holder = TensorContainer();
+  tensor_holder.hold_tensor(shape_1d_tensor);
+  auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+
+  return shape_1d_ivalue;
+}
+
 int64_t normalizeIndex(int64_t idx, int64_t list_size) {
   if (idx < 0) {
     // Handle negative indexing
@@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
 }
 
 // TODO: Conditionally enable truncation based on user setting
-at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
+at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
   // This function is basically same with the one in
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float
   // won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion
diff --git a/core/conversion/evaluators/eval_util.h b/core/conversion/evaluators/eval_util.h
index c63ead7461..5d0f050981 100644
--- a/core/conversion/evaluators/eval_util.h
+++ b/core/conversion/evaluators/eval_util.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "core/conversion/evaluators/evaluators.h"
 #include "torch/csrc/jit/ir/ir.h"
 
 namespace torch_tensorrt {
@@ -7,6 +8,14 @@ namespace core {
 namespace conversion {
 namespace evaluators {
 
+nvinfer1::ITensor* index_layer(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* input_tensor,
+    int64_t index);
+
+c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
+
 c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
 at::Tensor createTensorFromList(
     const torch::jit::IValue& data,
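dynamic_size_layer() composes two TensorRT layers: IShapeLayer materializes the input's shape as a 1-D Int32 tensor, and index_layer() gathers a single element out of it when a specific dimension was requested. In bare TensorRT terms the chain is roughly (a sketch, assuming a network `net` and input `in`; the selected dimension is illustrative):

    // Shape of `in` as a 1-D Int32 tensor, then pick out dimension 2.
    auto* shape = net->addShape(*in)->getOutput(0);
    static const int32_t kIdx[] = {2};
    nvinfer1::Weights w{nvinfer1::DataType::kINT32, kIdx, 1};
    auto* idx = net->addConstant(nvinfer1::Dims{1, {1}}, w)->getOutput(0);
    auto* dim2 = net->addGather(*shape, *idx, /*axis=*/0)->getOutput(0); // shape [1]

The result stays inside the network as a tensor, which is what lets downstream consumers (prim::ListConstruct, the shuffle layer) keep the whole shape computation in the engine.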
diff --git a/core/conversion/evaluators/evaluators.h b/core/conversion/evaluators/evaluators.h
index 2211fbc3e2..ba9610fac7 100644
--- a/core/conversion/evaluators/evaluators.h
+++ b/core/conversion/evaluators/evaluators.h
@@ -6,6 +6,8 @@
 
 #include "torch/csrc/jit/ir/ir.h"
 
+#include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/conversion/var/Var.h"
 
@@ -33,7 +35,8 @@ inline bool constTypesOnly(kwargs& args) {
 // to use the node itself to pull out arguments.
 // This means that you should iterate over node inputs vs. the args
 // when writing evaluators
-typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
+typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)>
+    NodeEvaluator;
 
 struct EvalOptions {
   std::set<c10::TypePtr> blacklisted_output_types;
@@ -72,7 +75,7 @@ struct EvalRegistration {
       : kind(_kind), evaluator(_evaluator), options(_options){};
 };
 
-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
 bool shouldEvalAtConversionTime(const torch::jit::Node* n);
 std::vector<std::string> getEvaluatorList();
 void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
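With the typedef above, a registration now looks like the following (hypothetical evaluator; "my::op" is illustrative, not an operator touched by this patch):

    auto my_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(
        {c10::Symbol::fromQualString("my::op"),
         [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
           (void)ctx; // ctx is now available for emitting TensorRT layers when needed
           return args.at(n->input(0)).unwrapToInt() + 1;
         },
         EvalOptions()});

Evaluators that stay purely host-side simply ignore ctx; the ones above (aten::size, aten::__getitem__, prim::ListConstruct) use it to add layers to ctx->net.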
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index 81a7bb9991..cbbc109982 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -1,12 +1,11 @@
 #include <limits>
 
-#include "torch/csrc/jit/ir/ir.h"
-//#include "torch/csrc/jit/ir/constants.h"
 #include "ATen/core/List.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
 #include "ATen/core/stack.h"
 #include "c10/util/intrusive_ptr.h"
+#include "torch/csrc/jit/ir/ir.h"
 #include "torch/torch.h"
 
 #include "core/conversion/evaluators/eval_macros.h"
@@ -24,7 +23,7 @@ auto prim_registrations =
     RegisterNodeEvaluators()
         .evaluator(
             {torch::jit::prim::Constant,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (n->output()->type()->kind() == at::FunctionType::Kind) {
                 return {};
               }
@@ -32,12 +31,12 @@ auto prim_registrations =
             }})
         .evaluator(
             {torch::jit::prim::NumToTensor,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
             }})
         .evaluator(
             {torch::jit::prim::ListUnpack,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
               const torch::jit::IValue* outputs = args.at(n->input()).IValue();
               auto outputVec = outputs->toList().vec();
@@ -45,7 +44,7 @@ auto prim_registrations =
             }})
         .evaluator(
             {torch::jit::prim::ListConstruct,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               const auto num_inputs = n->inputs().size();
               if (constTypesOnly(args)) {
                 c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -89,9 +88,8 @@ auto prim_registrations =
                   return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
                 }
               } else {
-                c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
-                c10::TypePtr elementType = lt->getElementType();
-                auto list = c10::impl::GenericList(elementType);
+                // List would be of IValues (with ITensors embedded in them)
+                auto list = c10::impl::GenericList(c10::AnyType::get());
                 list.reserve(num_inputs);
                 for (auto in : n->inputs()) {
                   if (args.at(in).isITensor()) {
@@ -103,8 +101,27 @@ auto prim_registrations =
                   if (args.at(in).IValue()->isNone()) {
                     auto ival = torch::jit::IValue();
                     list.emplace_back(std::move(ival));
+                  } else if (args.at(in).IValue()->isInt()) {
+                    auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+                        ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32));
+                    auto tensor_holder = TensorContainer();
+                    tensor_holder.hold_tensor(itensor);
+                    auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                    list.emplace_back(std::move(ival));
+                  } else if (args.at(in).IValue()->isDouble()) {
+                    auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+                        ctx, torch::tensor({args.at(in).unwrapToDouble()}).to(torch::kFloat));
+                    auto tensor_holder = TensorContainer();
+                    tensor_holder.hold_tensor(itensor);
+                    auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                    list.emplace_back(std::move(ival));
                   } else {
-                    list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+                    auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
+                        ctx, std::move(args.at(in).unwrapToTensor()));
+                    auto tensor_holder = TensorContainer();
+                    tensor_holder.hold_tensor(itensor);
+                    auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                    list.emplace_back(std::move(ival));
                   }
                 }
               }
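All three new ListConstruct branches funnel through tensor_to_const(), which freezes a host-side at::Tensor into the network as a constant layer and hands back its output ITensor*. In isolation (values illustrative):

    // Freeze a host scalar into the network as a 1-element Int32 constant so it
    // can participate in shape arithmetic as an ITensor.
    at::Tensor host_val = torch::tensor({4}).to(torch::kI32);
    nvinfer1::ITensor* as_itensor =
        torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, host_val);

Note the dtype normalization: ints become kI32 and doubles kFloat, matching what TensorRT shape tensors expect.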
(Any)"})}) .evaluator( {torch::jit::prim::TupleUnpack, - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map auto output = args.at(n->input()).IValue()->toTuple(); return c10::optional(std::move(output)); }}) .evaluator( {c10::Symbol::fromQualString("prim::unchecked_cast"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { return *(args.at(n->input(0)).IValue()); }}) .evaluator( {c10::Symbol::fromQualString("prim::Uninitialized"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { return c10::IValue::uninitialized(); }}) .evaluator( {c10::Symbol::fromQualString("prim::RaiseException"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional { auto exception = args.at(n->input(0)).IValue(); TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception); return {}; @@ -328,4 +345,4 @@ auto prim_registrations = } // namespace evaluators } // namespace conversion } // namespace core -} // namespace torch_tensorrt \ No newline at end of file +} // namespace torch_tensorrt diff --git a/core/conversion/var/Var.h b/core/conversion/var/Var.h index 6d7edcecde..41889cbbbb 100644 --- a/core/conversion/var/Var.h +++ b/core/conversion/var/Var.h @@ -43,6 +43,7 @@ class Var : torch::CustomClassHolder { c10::Scalar unwrapToScalar(); c10::List unwrapToIntList(c10::List default_val); c10::List unwrapToIntList(); + std::vector unwrapToITensorList(); c10::List unwrapToDoubleList(c10::List default_val); c10::List unwrapToDoubleList(); c10::List unwrapToBoolList(c10::List default_val); @@ -59,6 +60,20 @@ class Var : torch::CustomClassHolder { bool isIValue() const; bool isITensor() const; bool isNone() const; + + bool isInt(); + bool isDouble(); + bool isBool(); + bool isString(); + bool isScalar(); + bool isTensor(); + bool isIntList(); + bool isDoubleList(); + bool isBoolList(); + bool isTensorList(); + bool isITensorList(); + bool isList(); + Var::Type type() const; std::string type_name() const; diff --git a/core/conversion/var/Var_inl.h b/core/conversion/var/Var_inl.h index 13760a908c..a98519abe1 100644 --- a/core/conversion/var/Var_inl.h +++ b/core/conversion/var/Var_inl.h @@ -4,6 +4,13 @@ namespace torch_tensorrt { namespace core { namespace conversion { +#define DEFINE_IS_IVAL_TYPE(method_variant) \ + inline bool Var::is##method_variant() { \ + TORCHTRT_CHECK( \ + isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); \ + return ptr_.ivalue->is##method_variant(); \ + } + #define DEFINE_UNWRAP_TO(ival_type, method_variant) \ template <> \ inline ival_type Var::unwrapTo() { \ @@ -34,6 +41,18 @@ namespace conversion { return this->unwrapTo(); \ } +DEFINE_IS_IVAL_TYPE(Int) +DEFINE_IS_IVAL_TYPE(Double) +DEFINE_IS_IVAL_TYPE(Bool) +DEFINE_IS_IVAL_TYPE(String) +DEFINE_IS_IVAL_TYPE(Scalar) +DEFINE_IS_IVAL_TYPE(Tensor) +DEFINE_IS_IVAL_TYPE(IntList) +DEFINE_IS_IVAL_TYPE(DoubleList) +DEFINE_IS_IVAL_TYPE(BoolList) +DEFINE_IS_IVAL_TYPE(TensorList) +DEFINE_IS_IVAL_TYPE(List) + DEFINE_UNWRAP_TO(at::Tensor, Tensor) DEFINE_UNWRAP_TO(int64_t, Int) DEFINE_UNWRAP_TO(double, Double) diff --git 
diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD
index c34aa09372..709187e1b2 100644
--- a/tests/cpp/BUILD
+++ b/tests/cpp/BUILD
@@ -16,6 +16,7 @@ test_suite(
         ":test_compiled_modules",
         ":test_default_input_types",
         ":test_dynamic_fallback",
+        ":test_dynamic_size",
        ":test_example_tensors",
         ":test_module_fallback",
         ":test_modules_as_engines",
@@ -32,6 +33,7 @@ test_suite(
         ":test_compiled_modules",
         ":test_default_input_types",
         ":test_dynamic_fallback",
+        ":test_dynamic_size",
         ":test_example_tensors",
         ":test_module_fallback",
         ":test_modules_as_engines",
@@ -142,6 +144,18 @@ cc_test(
     }),
 )
 
+cc_test(
+    name = "test_dynamic_size",
+    srcs = ["test_dynamic_size.cpp"],
+    deps = [
+        "//tests/util",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+)
+
 cc_test(
     name = "test_collections",
     srcs = ["test_collections.cpp"],
diff --git a/tests/cpp/test_dynamic_size.cpp b/tests/cpp/test_dynamic_size.cpp
new file mode 100644
index 0000000000..202b4f5ddc
--- /dev/null
+++ b/tests/cpp/test_dynamic_size.cpp
@@ -0,0 +1,91 @@
+#include <string>
+#include <torch/torch.h>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Converters, ATenResizeDynamicShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x : Tensor):
+      %3 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=-1]()
+      %28 : int = aten::size(%x, %3)
+      %30 : int[] = prim::ListConstruct(%28, %2)
+      %6 : Tensor = aten::reshape(%x, %30)
+      return (%6))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenResizeDynamicInputCorrectly) {
+  const auto graph = R"IR(
+    graph(%x : Tensor):
+      %2 : int[] = prim::Constant[value=[-1, 4, 64]]()
+      %3 : Tensor = aten::reshape(%x, %2)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %3 : int = prim::Constant[value=-1]()
+      %2 : int = prim::Constant[value=0]()
+      %size.1 : int[] = aten::size(%x.1)
+      %37 : int = aten::__getitem__(%size.1, %2)
+      %39 : int[] = prim::ListConstruct(%37, %3)
+      %7 : Tensor = aten::reshape(%x.1, %39)
+      return (%7))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
\ No newline at end of file
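Assuming the repository's usual Bazel workflow, the new suite can be run in isolation with something like:

    bazel test //tests/cpp:test_dynamic_size

Each test follows the same recipe: run the TorchScript graph on the JIT as ground truth, build and execute a TensorRT engine with RunGraphEngineDynamic (the trailing `true` presumably marks the input as dynamically shaped), and compare the outputs with almostEqual.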