diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 8025a1086b..63c7713d01 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -271,37 +271,229 @@ auto select_registrations TORCHTRT_UNUSED =
                auto ts = args[1].IValue()->toListRef();
 
                std::vector<nvinfer1::ITensor*> tensors;
-               for (auto t : ts) {
+               std::vector<int32_t> adv_idx_indices;
+               for (auto i = 0; i < ts.size(); i++) {
+                 auto t = ts[i];
                  if (t.isTensor()) {
-                   auto torch_tensor = t.toTensor();
+                   auto torch_tensor = t.toTensor().to(torch::kInt32);
                    tensors.push_back(tensor_to_const(ctx, torch_tensor));
+                   adv_idx_indices.push_back(i);
                  } else {
-                   auto cont = t.toCustomClass<TensorContainer>();
-                   tensors.push_back(cont->tensor());
+                   // IValue
+                   if (!t.isNone()) {
+                     adv_idx_indices.push_back(i);
+                     auto cont = t.toCustomClass<TensorContainer>();
+                     // Set datatype for indices tensor to INT32
+                     auto identity = ctx->net->addIdentity(*cont->tensor());
+                     identity->setOutputType(0, nvinfer1::DataType::kINT32);
+                     tensors.push_back(identity->getOutput(0));
+                   }
                  }
                }
 
-               // In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
-               // indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
-               // indexes the self tensor along dimension 0.
-               TORCHTRT_CHECK(
-                   tensors.size() == 1,
-                   "In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
-               auto indicesTensor = tensors[0];
-               // Set datatype for indices tensor to INT32
-               auto identity = ctx->net->addIdentity(*indicesTensor);
-               identity->setOutputType(0, nvinfer1::DataType::kINT32);
-               indicesTensor = identity->getOutput(0);
+               if (tensors.size() == 0) {
+                 auto identity_out = ctx->net->addIdentity(*in)->getOutput(0);
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity_out);
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               } else if (tensors.size() == 1) {
+                 auto indicesTensor = tensors[0];
+                 // Set datatype for indices tensor to INT32
+                 auto identity = ctx->net->addIdentity(*indicesTensor);
+                 identity->setOutputType(0, nvinfer1::DataType::kINT32);
+                 indicesTensor = identity->getOutput(0);
+
+                 // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
+                 // from
+                 auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
+                 TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+                 auto gather_out = gather_layer->getOutput(0);
+
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               } else {
+                 auto inDims = in->getDimensions();
+                 int rank = inDims.nbDims;
+                 LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results.");
+                 int adv_idx_count = adv_idx_indices.size();
+                 auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0);
+
+                 std::vector<nvinfer1::ITensor*> dim_tensor_list;
+                 for (int i = 0; i < rank; i++) {
+                   auto dim_tensor =
+                       ctx->net
+                           ->addGather(*in_shape_itensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                           ->getOutput(0);
+                   dim_tensor_list.push_back(dim_tensor);
+                 }
 
-               // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
-               // from
-               auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
-               TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-               auto gather_out = gather_layer->getOutput(0);
+                 // t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+                 // where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+                 // for ":".
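+                 // e.g. for a rank-3 input indexed as x[idx0, :, idx2], adv_idx_indices = {0, 2}, so new_order
+                 // below becomes {0, 2, 1}: advanced axes first, ":" axes after.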
+                 auto in_transpose_layer = ctx->net->addShuffle(*in);
+                 TORCHTRT_CHECK(in_transpose_layer, "Unable to create shuffle layer from node: " << *n);
+                 nvinfer1::Permutation permute;
+                 std::vector<int32_t> new_order;
+                 for (int i = 0; i < adv_idx_count; i++) {
+                   new_order.push_back(adv_idx_indices[i]);
+                 }
+                 for (int i = 0; i < rank; i++) {
+                   if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                     new_order.push_back(i);
+                   }
+                 }
+                 std::copy(new_order.begin(), new_order.end(), permute.order);
+                 in_transpose_layer->setSecondTranspose(permute);
+                 auto shuffle_out = in_transpose_layer->getOutput(0);
+
+                 // t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n]
+                 nvinfer1::ITensor* flatten_tensor = NULL;
+                 {
+                   auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0);
+                   auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+                   for (int i = 0; i < adv_idx_count; i++) {
+                     auto dim_tensor =
+                         ctx->net
+                             ->addGather(
+                                 *shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                             ->getOutput(0);
+                     d0 = add_elementwise(
+                              ctx,
+                              nvinfer1::ElementWiseOperation::kPROD,
+                              d0,
+                              dim_tensor,
+                              std::string("compute_dim0_") + std::to_string(i))
+                              ->getOutput(0);
+                   }
 
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+                   auto d1 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+                   for (int i = adv_idx_count; i < rank; i++) {
+                     auto dim_tensor =
+                         ctx->net
+                             ->addGather(
+                                 *shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                             ->getOutput(0);
+                     d1 = add_elementwise(
+                              ctx,
+                              nvinfer1::ElementWiseOperation::kPROD,
+                              d1,
+                              dim_tensor,
+                              std::string("compute_dim1_") + std::to_string(i))
+                              ->getOutput(0);
+                   }
 
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                   std::vector<nvinfer1::ITensor*> concat_tensors;
+                   concat_tensors.push_back(d0);
+                   concat_tensors.push_back(d1);
+                   auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+
+                   auto shuffle = ctx->net->addShuffle(*shuffle_out);
+                   shuffle->setInput(1, *concat_layer->getOutput(0));
+                   flatten_tensor = shuffle->getOutput(0);
+                   LOG_DEBUG(flatten_tensor->getDimensions());
+                 }
+
+                 // tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+                 // j dimension of input x.
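+                 // e.g. for an input of shape [5, 10, 4] indexed as x[idx0, idx1, :], the flattened tensor is
+                 // [50, 4] and the gather index is idx0 * 10 + idx1; the loop below accumulates this sum from the
+                 // last advanced axis back to the first.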
+                 nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
+                 nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
+                 for (int i = adv_idx_count - 2; i >= 0; i--) {
+                   nvinfer1::ITensor* adv_index = add_elementwise(
+                                                      ctx,
+                                                      nvinfer1::ElementWiseOperation::kPROD,
+                                                      tensors[i],
+                                                      multiplier,
+                                                      std::string("adv_index_") + std::to_string(i))
+                                                      ->getOutput(0);
+                   cum_adv_index = add_elementwise(
+                                       ctx,
+                                       nvinfer1::ElementWiseOperation::kSUM,
+                                       cum_adv_index,
+                                       adv_index,
+                                       std::string("cum_adv_index_") + std::to_string(i))
+                                       ->getOutput(0);
+                   multiplier = add_elementwise(
+                                    ctx,
+                                    nvinfer1::ElementWiseOperation::kPROD,
+                                    multiplier,
+                                    dim_tensor_list[adv_idx_indices[i]],
+                                    std::string("multiplier_") + std::to_string(i))
+                                    ->getOutput(0);
+                 }
+
+                 // perform gather
+                 auto gather_out = ctx->net->addGather(*flatten_tensor, *cum_adv_index, 0)->getOutput(0);
+
+                 nvinfer1::ITensor* reshape_output = NULL;
+                 {
+                   auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0);
+                   // check if all advanced indices are consecutive.
+                   if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) {
+                     // unfold regular index axes
+                     std::vector<nvinfer1::ITensor*> concat_tensors;
+                     concat_tensors.push_back(tensor_to_const(ctx, torch::tensor({-1}, torch::kInt32)));
+                     for (int i = 0; i < rank; i++) {
+                       if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                         nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                         concat_tensors.push_back(current_dim);
+                       }
+                     }
+                     auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+                     auto regular_index_shuffle_layer = ctx->net->addShuffle(*gather_out);
+                     regular_index_shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+                     auto unfold_tensor = regular_index_shuffle_layer->getOutput(0);
+
+                     // Transpose folded advanced indexed axis to its original location.
+                     auto transpose_advanced_shuffle_layer = ctx->net->addShuffle(*unfold_tensor);
+                     nvinfer1::Permutation permute;
+                     std::vector<int32_t> new_order;
+                     for (int i = 1; i < adv_idx_indices[0] + 1; i++) {
+                       new_order.push_back(i);
+                     }
+                     new_order.push_back(0);
+                     for (int i = adv_idx_indices[0] + 1; i < rank - adv_idx_count + 1; i++) {
+                       new_order.push_back(i);
+                     }
+                     std::copy(new_order.begin(), new_order.end(), permute.order);
+                     transpose_advanced_shuffle_layer->setSecondTranspose(permute);
+                     auto shuffle_out = transpose_advanced_shuffle_layer->getOutput(0);
+
+                     // unfold advanced index axes
+                     std::vector<nvinfer1::ITensor*> concat_final_tensors;
+                     for (int i = 0; i < adv_idx_indices[0]; i++) {
+                       nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                       concat_final_tensors.push_back(current_dim);
+                     }
+                     concat_final_tensors.push_back(cum_adv_index_shape_tensor);
+                     for (int i = adv_idx_indices[0]; i < rank; i++) {
+                       if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                         nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                         concat_final_tensors.push_back(current_dim);
+                       }
+                     }
+                     auto concat_final_shape_layer =
+                         ctx->net->addConcatenation(concat_final_tensors.data(), concat_final_tensors.size());
+                     auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
+                     unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
+                     reshape_output = unfold_advanced_shuffle_layer->getOutput(0);
+                   } else {
+                     std::vector<nvinfer1::ITensor*> concat_tensors;
+                     concat_tensors.push_back(cum_adv_index_shape_tensor);
+                     for (int i = 0; i < rank; i++) {
+                       if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+                         nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+                         concat_tensors.push_back(current_dim);
+                       }
+                     }
+                     auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+                     auto shuffle_layer = ctx->net->addShuffle(*gather_out);
+                     shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+                     reshape_output = shuffle_layer->getOutput(0);
+                   }
+                 }
+
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], reshape_output);
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               }
                return true;
              }})
         .pattern(
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index 2245ca05dc..81a7bb9991 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -100,7 +100,12 @@ auto prim_registrations =
                   auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
                   list.emplace_back(std::move(ival));
                 } else {
-                  list.emplace_back(std::move(args.at(in).unwrapToTensor()));
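+                  // A None entry marks a dimension indexed with ":"; forward it as an empty IValue so that
+                  // converters such as aten::index can detect it instead of failing to unwrap a tensor.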
+                  if (args.at(in).IValue()->isNone()) {
+                    auto ival = torch::jit::IValue();
+                    list.emplace_back(std::move(ival));
+                  } else {
+                    list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+                  }
                 }
               }
               return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index e9405c0155..c84446ad68 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -776,7 +776,7 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
-TEST(Converters, ATenIndexTensorConvertsCorrectly) {
+TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor,
             %index : Tensor):
@@ -802,6 +802,125 @@ TEST(Converters, ATenIndexTensorConvertsCorrectly) {
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor,
+            %index0 : Tensor,
+            %index1 : Tensor,
+            %index2 : Tensor):
+        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2)
+        %19 : Tensor = aten::index(%x.1, %18)
+        return (%19))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
+  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
+  auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
+  auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
+  auto index0_trt = index0.to(torch::kInt32);
+  auto index1_trt = index1.to(torch::kInt32);
+  auto index2_trt = index2.to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor,
+            %index0 : Tensor,
+            %index1 : Tensor):
+        %5 : NoneType = prim::Constant()
+        %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %5)
+        %19 : Tensor = aten::index(%x.1, %18)
+        return (%19))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
+  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
+  auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong);
+  auto index0_trt = index0.to(torch::kInt32);
+  auto index1_trt = index1.to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt});
+  LOG_DEBUG(trt_results);
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
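+// The None placeholder may sit at any position in the index list; the two tests below cover the middle
+// (x[idx0, :, idx1]) and leading (x[:, idx0, idx1]) positions.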
+TEST(Converters, ATenIndexTensorIdx0NoneIdx1ConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor,
+            %index0 : Tensor,
+            %index1 : Tensor):
+        %5 : NoneType = prim::Constant()
+        %18 : Tensor?[] = prim::ListConstruct(%index0, %5, %index1)
+        %19 : Tensor = aten::index(%x.1, %18)
+        return (%19))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
+  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
+  auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
+  auto index0_trt = index0.to(torch::kInt32);
+  auto index1_trt = index1.to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
+TEST(Converters, ATenIndexTensorNoneIdx0Idx1ConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor,
+            %index0 : Tensor,
+            %index1 : Tensor):
+        %5 : NoneType = prim::Constant()
+        %18 : Tensor?[] = prim::ListConstruct(%5, %index0, %index1)
+        %19 : Tensor = aten::index(%x.1, %18)
+        return (%19))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA});
+  auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong);
+  auto index1 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong);
+  auto index0_trt = index0.to(torch::kInt32);
+  auto index1_trt = index1.to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Converters, ATenUnbindConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):