support multiple indices for aten::index.Tensor (#1309)
Signed-off-by: Ruoqian Guo <ruoqiang@nvidia.com>

ruoqianguo authored Aug 25, 2022
1 parent 4d32d47 commit 22c0e17
Showing 3 changed files with 340 additions and 24 deletions.
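
For reference, a minimal libtorch sketch (my illustration, not code from this commit) of the semantics the converter now handles: several index tensors applied at once, with None/":" entries passing their dimensions through untouched.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::rand({4, 5, 6});
  auto i0 = torch::tensor({0, 2, 3});
  auto i1 = torch::tensor({1, 1, 4});

  // Two index tensors: picks x[i0[k], i1[k], :] for each k -> shape [3, 6]
  auto y = x.index({i0, i1});

  // A ":" (a None index in TorchScript) between them leaves dim 1 intact -> shape [3, 5]
  auto z = x.index({i0, torch::indexing::Slice(), i1});

  std::cout << y.sizes() << " " << z.sizes() << std::endl;
}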
236 changes: 214 additions & 22 deletions core/conversion/converters/impl/select.cpp
@@ -271,37 +271,229 @@ auto select_registrations TORCHTRT_UNUSED =
   auto ts = args[1].IValue()->toListRef();

   std::vector<nvinfer1::ITensor*> tensors;
-  for (auto t : ts) {
+  std::vector<int32_t> adv_idx_indices;
+  for (auto i = 0; i < ts.size(); i++) {
+    auto t = ts[i];
     if (t.isTensor()) {
-      auto torch_tensor = t.toTensor();
+      auto torch_tensor = t.toTensor().to(torch::kInt32);
       tensors.push_back(tensor_to_const(ctx, torch_tensor));
+      adv_idx_indices.push_back(i);
     } else {
-      auto cont = t.toCustomClass<TensorContainer>();
-      tensors.push_back(cont->tensor());
+      // IValue
+      if (!t.isNone()) {
+        adv_idx_indices.push_back(i);
+        auto cont = t.toCustomClass<TensorContainer>();
+        // Set datatype for indices tensor to INT32
+        auto identity = ctx->net->addIdentity(*cont->tensor());
+        identity->setOutputType(0, nvinfer1::DataType::kINT32);
+        tensors.push_back(identity->getOutput(0));
+      }
     }
   }

-  // In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
-  // indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
-  // indexes the self tensor along dimension 0.
-  TORCHTRT_CHECK(
-      tensors.size() == 1,
-      "In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
-  auto indicesTensor = tensors[0];
-  // Set datatype for indices tensor to INT32
-  auto identity = ctx->net->addIdentity(*indicesTensor);
-  identity->setOutputType(0, nvinfer1::DataType::kINT32);
-  indicesTensor = identity->getOutput(0);
-
-  // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
-  // from
-  auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
-  TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-  auto gather_out = gather_layer->getOutput(0);
-
-  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
-
-  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  if (tensors.size() == 0) {
+    auto identity_out = ctx->net->addIdentity(*in)->getOutput(0);
+    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity_out);
+    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  } else if (tensors.size() == 1) {
+    auto indicesTensor = tensors[0];
+    // Set datatype for indices tensor to INT32
+    auto identity = ctx->net->addIdentity(*indicesTensor);
+    identity->setOutputType(0, nvinfer1::DataType::kINT32);
+    indicesTensor = identity->getOutput(0);
+
+    // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
+    // from
+    auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
+    TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+    auto gather_out = gather_layer->getOutput(0);
+
+    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  } else {
+    auto inDims = in->getDimensions();
+    int rank = inDims.nbDims;
+    LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results.");
+    int adv_idx_count = adv_idx_indices.size();
+    auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0);
+
+    std::vector<nvinfer1::ITensor*> dim_tensor_list;
+    for (int i = 0; i < rank; i++) {
+      auto dim_tensor =
+          ctx->net->addGather(*in_shape_itensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+              ->getOutput(0);
+      dim_tensor_list.push_back(dim_tensor);
+    }
+
+    // t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+    // where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+    // for ":".
+    auto in_transpose_layer = ctx->net->addShuffle(*in);
+    TORCHTRT_CHECK(in_transpose_layer, "Unable to create shuffle layer from node: " << *n);
+    nvinfer1::Permutation permute;
+    std::vector<int32_t> new_order;
+    for (int i = 0; i < adv_idx_count; i++) {
+      new_order.push_back(adv_idx_indices[i]);
+    }
+    for (int i = 0; i < rank; i++) {
+      if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+        new_order.push_back(i);
+      }
+    }
+    std::copy(new_order.begin(), new_order.end(), permute.order);
+    in_transpose_layer->setSecondTranspose(permute);
+    auto shuffle_out = in_transpose_layer->getOutput(0);
+
+    // t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n]
+    nvinfer1::ITensor* flatten_tensor = NULL;
+    {
+      auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0);
+      auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+      for (int i = 0; i < adv_idx_count; i++) {
+        auto dim_tensor =
+            ctx->net
+                ->addGather(*shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                ->getOutput(0);
+        d0 = add_elementwise(
+                 ctx,
+                 nvinfer1::ElementWiseOperation::kPROD,
+                 d0,
+                 dim_tensor,
+                 std::string("compute_dim0_") + std::to_string(i))
+                 ->getOutput(0);
+      }
+
+      auto d1 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
+      for (int i = adv_idx_count; i < rank; i++) {
+        auto dim_tensor =
+            ctx->net
+                ->addGather(*shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
+                ->getOutput(0);
+        d1 = add_elementwise(
+                 ctx,
+                 nvinfer1::ElementWiseOperation::kPROD,
+                 d1,
+                 dim_tensor,
+                 std::string("compute_dim1_") + std::to_string(i))
+                 ->getOutput(0);
+      }
+
+      std::vector<nvinfer1::ITensor*> concat_tensors;
+      concat_tensors.push_back(d0);
+      concat_tensors.push_back(d1);
+      auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+
+      auto shuffle = ctx->net->addShuffle(*shuffle_out);
+      shuffle->setInput(1, *concat_layer->getOutput(0));
+      flatten_tensor = shuffle->getOutput(0);
+      LOG_DEBUG(flatten_tensor->getDimensions());
+    }
+
+    // tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+    // j dimension of input x.
+    nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
+    nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
+    for (int i = adv_idx_count - 2; i >= 0; i--) {
+      nvinfer1::ITensor* adv_index = add_elementwise(
+                                         ctx,
+                                         nvinfer1::ElementWiseOperation::kPROD,
+                                         tensors[i],
+                                         multiplier,
+                                         std::string("adv_index_") + std::to_string(i))
+                                         ->getOutput(0);
+      cum_adv_index = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kSUM,
+                          cum_adv_index,
+                          adv_index,
+                          std::string("cum_adv_index_") + std::to_string(i))
+                          ->getOutput(0);
+      multiplier = add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kPROD,
+                       multiplier,
+                       dim_tensor_list[adv_idx_indices[i]],
+                       std::string("multiplier_") + std::to_string(i))
+                       ->getOutput(0);
+    }
+
+    // perform gather
+    auto gather_out = ctx->net->addGather(*flatten_tensor, *cum_adv_index, 0)->getOutput(0);
+
+    nvinfer1::ITensor* reshape_output = NULL;
+    {
+      auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0);
+      // check if all advanced indices are consecutive.
+      if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) {
+        // unfold regular index axes
+        std::vector<nvinfer1::ITensor*> concat_tensors;
+        concat_tensors.push_back(tensor_to_const(ctx, torch::tensor({-1}, torch::kInt32)));
+        for (int i = 0; i < rank; i++) {
+          if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+            nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+            concat_tensors.push_back(current_dim);
+          }
+        }
+        auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+        auto regular_index_shuffle_layer = ctx->net->addShuffle(*gather_out);
+        regular_index_shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+        auto unfold_tensor = regular_index_shuffle_layer->getOutput(0);
+
+        // Transpose folded advanced indexed axis to its original location.
+        auto transpose_advanced_shuffle_layer = ctx->net->addShuffle(*unfold_tensor);
+        nvinfer1::Permutation permute;
+        std::vector<int32_t> new_order;
+        for (int i = 1; i < adv_idx_indices[0] + 1; i++) {
+          new_order.push_back(i);
+        }
+        new_order.push_back(0);
+        for (int i = adv_idx_indices[0] + 1; i < rank - adv_idx_count + 1; i++) {
+          new_order.push_back(i);
+        }
+        std::copy(new_order.begin(), new_order.end(), permute.order);
+        transpose_advanced_shuffle_layer->setSecondTranspose(permute);
+        auto shuffle_out = transpose_advanced_shuffle_layer->getOutput(0);
+
+        // unfold advanced index axes
+        std::vector<nvinfer1::ITensor*> concat_final_tensors;
+        for (int i = 0; i < adv_idx_indices[0]; i++) {
+          nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+          concat_final_tensors.push_back(current_dim);
+        }
+        concat_final_tensors.push_back(cum_adv_index_shape_tensor);
+        for (int i = adv_idx_indices[0]; i < rank; i++) {
+          if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+            nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+            concat_final_tensors.push_back(current_dim);
+          }
+        }
+        auto concat_final_shape_layer =
+            ctx->net->addConcatenation(concat_final_tensors.data(), concat_final_tensors.size());
+        auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
+        unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
+        reshape_output = unfold_advanced_shuffle_layer->getOutput(0);
+      } else {
+        std::vector<nvinfer1::ITensor*> concat_tensors;
+        concat_tensors.push_back(cum_adv_index_shape_tensor);
+        for (int i = 0; i < rank; i++) {
+          if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
+            nvinfer1::ITensor* current_dim = dim_tensor_list[i];
+            concat_tensors.push_back(current_dim);
+          }
+        }
+        auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
+        auto shuffle_layer = ctx->net->addShuffle(*gather_out);
+        shuffle_layer->setInput(1, *concat_layer->getOutput(0));
+        reshape_output = shuffle_layer->getOutput(0);
+      }
+    }
+
+    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], reshape_output);
+    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  }
   return true;
 }})
     .pattern(
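
The heart of the multi-index path above is the transpose-flatten-gather trick: the advanced axes are moved to the front, collapsed into one axis, and a single flat index cum_adv_index = \sum_i ind_i * \prod_{j>i} x_j drives one IGatherLayer. A standalone libtorch sketch of that arithmetic (my illustration, under the commit's stated assumption of non-negative indices, not the converter code itself):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::rand({4, 5, 6});
  auto i0 = torch::tensor({0, 2, 3});
  auto i1 = torch::tensor({1, 1, 4});

  // Flat offset over the two advanced (leading) axes:
  // cum_adv_index = i0 * x_1 + i1, the m = 2 case of
  // sum_i ind_i * prod_{j > i} x_j.
  auto cum_adv_index = i0 * x.size(1) + i1;

  // Collapse the advanced axes into dim 0, then gather once along it.
  auto flat = x.reshape({x.size(0) * x.size(1), x.size(2)});
  auto gathered = flat.index_select(0, cum_adv_index);

  // Matches direct advanced indexing.
  std::cout << gathered.equal(x.index({i0, i1})) << std::endl; // prints 1
}

The reshape at the end of the converter restores the non-advanced dimensions; when the advanced indices are consecutive in the original tensor, the folded axis is additionally transposed back to where that block of axes originally sat.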
7 changes: 6 additions & 1 deletion core/conversion/evaluators/prim.cpp
@@ -100,7 +100,12 @@ auto prim_registrations =
       auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
       list.emplace_back(std::move(ival));
     } else {
-      list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+      if (args.at(in).IValue()->isNone()) {
+        auto ival = torch::jit::IValue();
+        list.emplace_back(std::move(ival));
+      } else {
+        list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+      }
     }
   }
   return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
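
This evaluator change exists so that None entries in an index list survive prim::ListConstruct as empty IValues, which the select.cpp converter above detects via isNone() and skips. A minimal sketch of the IValue behavior being relied on (my illustration, not this file's code):

#include <torch/torch.h>
#include <iostream>

int main() {
  c10::IValue none;                        // a default-constructed IValue is None
  c10::IValue idx(torch::tensor({0, 2}));  // a real index tensor
  std::cout << none.isNone() << " " << idx.isNone() << std::endl; // prints 1 0
}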
