Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple indices for aten::index.Tensor #1309

Merged
merged 1 commit into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 214 additions & 22 deletions core/conversion/converters/impl/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,37 +271,229 @@ auto select_registrations TORCHTRT_UNUSED =
auto ts = args[1].IValue()->toListRef();

std::vector<nvinfer1::ITensor*> tensors;
for (auto t : ts) {
std::vector<int32_t> adv_idx_indices;
for (auto i = 0; i < ts.size(); i++) {
auto t = ts[i];
if (t.isTensor()) {
auto torch_tensor = t.toTensor();
auto torch_tensor = t.toTensor().to(torch::kInt32);
tensors.push_back(tensor_to_const(ctx, torch_tensor));
adv_idx_indices.push_back(i);
} else {
auto cont = t.toCustomClass<TensorContainer>();
tensors.push_back(cont->tensor());
// IValue
if (!t.isNone()) {
adv_idx_indices.push_back(i);
auto cont = t.toCustomClass<TensorContainer>();
// Set datatype for indices tensor to INT32
auto identity = ctx->net->addIdentity(*cont->tensor());
identity->setOutputType(0, nvinfer1::DataType::kINT32);
tensors.push_back(identity->getOutput(0));
}
}
}

// In TorchScript, aten::index.Tensor indexes the self tensor along each of its dimensions using the
// given index tensors. In this version of Torch-TensorRT, it can only receive one index tensor, which
// means it only indexes the self tensor along dimension 0.
TORCHTRT_CHECK(
tensors.size() == 1,
"In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
auto indicesTensor = tensors[0];
// Set datatype for indices tensor to INT32
auto identity = ctx->net->addIdentity(*indicesTensor);
identity->setOutputType(0, nvinfer1::DataType::kINT32);
indicesTensor = identity->getOutput(0);
if (tensors.size() == 0) {
auto identity_out = ctx->net->addIdentity(*in)->getOutput(0);
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], identity_out);
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
} else if (tensors.size() == 1) {
auto indicesTensor = tensors[0];
// Set datatype for indices tensor to INT32
auto identity = ctx->net->addIdentity(*indicesTensor);
identity->setOutputType(0, nvinfer1::DataType::kINT32);
indicesTensor = identity->getOutput(0);

// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
// from
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto gather_out = gather_layer->getOutput(0);

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
} else {
auto inDims = in->getDimensions();
int rank = inDims.nbDims;
LOG_WARNING("If indices include negative values, the exported graph will produce incorrect results.");
int adv_idx_count = adv_idx_indices.size();
auto in_shape_itensor = ctx->net->addShape(*in)->getOutput(0);

std::vector<nvinfer1::ITensor*> dim_tensor_list;
for (int i = 0; i < rank; i++) {
auto dim_tensor =
ctx->net
->addGather(*in_shape_itensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
->getOutput(0);
dim_tensor_list.push_back(dim_tensor);
}

// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
// from
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto gather_out = gather_layer->getOutput(0);
// t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
// where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
// for ":".
auto in_transpose_layer = ctx->net->addShuffle(*in);
TORCHTRT_CHECK(in_transpose_layer, "Unable to create shuffle layer from node: " << *n);
nvinfer1::Permutation permute;
std::vector<int32_t> new_order;
for (int i = 0; i < adv_idx_count; i++) {
new_order.push_back(adv_idx_indices[i]);
}
for (int i = 0; i < rank; i++) {
if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
new_order.push_back(i);
}
}
std::copy(new_order.begin(), new_order.end(), permute.order);
in_transpose_layer->setSecondTranspose(permute);
auto shuffle_out = in_transpose_layer->getOutput(0);

// t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] -> t: [x_1*x_2* ...*x_m, y_1*y_2* ...*y_n]
nvinfer1::ITensor* flatten_tensor = NULL;
{
auto shuffle_shape_tensor = ctx->net->addShape(*shuffle_out)->getOutput(0);
auto d0 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
for (int i = 0; i < adv_idx_count; i++) {
auto dim_tensor =
ctx->net
->addGather(
*shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
->getOutput(0);
d0 = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
d0,
dim_tensor,
std::string("compute_dim0_") + std::to_string(i))
->getOutput(0);
}

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
auto d1 = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
for (int i = adv_idx_count; i < rank; i++) {
auto dim_tensor =
ctx->net
->addGather(
*shuffle_shape_tensor, *tensor_to_const(ctx, torch::tensor({i}, torch::kInt32)), 0)
->getOutput(0);
d1 = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
d1,
dim_tensor,
std::string("compute_dim1_") + std::to_string(i))
->getOutput(0);
}

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
std::vector<nvinfer1::ITensor*> concat_tensors;
concat_tensors.push_back(d0);
concat_tensors.push_back(d1);
auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());

auto shuffle = ctx->net->addShuffle(*shuffle_out);
shuffle->setInput(1, *concat_layer->getOutput(0));
flatten_tensor = shuffle->getOutput(0);
LOG_DEBUG(flatten_tensor->getDimensions());
}

// tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
// j dimension of input x.
nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]];
nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1];
for (int i = adv_idx_count - 2; i >= 0; i--) {
nvinfer1::ITensor* adv_index = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
tensors[i],
multiplier,
std::string("adv_index_") + std::to_string(i))
->getOutput(0);
cum_adv_index = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kSUM,
cum_adv_index,
adv_index,
std::string("cum_adv_index_") + std::to_string(i))
->getOutput(0);
multiplier = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
multiplier,
dim_tensor_list[adv_idx_indices[i]],
std::string("multiplier_") + std::to_string(i))
->getOutput(0);
}

// perform gather
auto gather_out = ctx->net->addGather(*flatten_tensor, *cum_adv_index, 0)->getOutput(0);

nvinfer1::ITensor* reshape_output = NULL;
{
auto cum_adv_index_shape_tensor = ctx->net->addShape(*cum_adv_index)->getOutput(0);
// check if all advanced indices are consecutive.
if (adv_idx_count == (adv_idx_indices[adv_idx_count - 1] - adv_idx_indices[0] + 1)) {
// unfold regular index axes
std::vector<nvinfer1::ITensor*> concat_tensors;
concat_tensors.push_back(tensor_to_const(ctx, torch::tensor({-1}, torch::kInt32)));
for (int i = 0; i < rank; i++) {
if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
nvinfer1::ITensor* current_dim = dim_tensor_list[i];
concat_tensors.push_back(current_dim);
}
}
auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
auto regular_index_shuffle_layer = ctx->net->addShuffle(*gather_out);
regular_index_shuffle_layer->setInput(1, *concat_layer->getOutput(0));
auto unfold_tensor = regular_index_shuffle_layer->getOutput(0);

// Transpose folded advanced indexed axis to its original location.
auto transpose_advanced_shuffle_layer = ctx->net->addShuffle(*unfold_tensor);
nvinfer1::Permutation permute;
std::vector<int32_t> new_order;
for (int i = 1; i < adv_idx_indices[0] + 1; i++) {
new_order.push_back(i);
}
new_order.push_back(0);
for (int i = adv_idx_indices[0] + 1; i < rank - adv_idx_count + 1; i++) {
new_order.push_back(i);
}
std::copy(new_order.begin(), new_order.end(), permute.order);
transpose_advanced_shuffle_layer->setSecondTranspose(permute);
auto shuffle_out = transpose_advanced_shuffle_layer->getOutput(0);

// unfold advanced index axes
std::vector<nvinfer1::ITensor*> concat_final_tensors;
for (int i = 0; i < adv_idx_indices[0]; i++) {
nvinfer1::ITensor* current_dim = dim_tensor_list[i];
concat_final_tensors.push_back(current_dim);
}
concat_final_tensors.push_back(cum_adv_index_shape_tensor);
for (int i = adv_idx_indices[0]; i < rank; i++) {
if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
nvinfer1::ITensor* current_dim = dim_tensor_list[i];
concat_final_tensors.push_back(current_dim);
}
}
auto concat_final_shape_layer =
ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
auto unfold_advanced_shuffle_layer = ctx->net->addShuffle(*shuffle_out);
unfold_advanced_shuffle_layer->setInput(1, *concat_final_shape_layer->getOutput(0));
reshape_output = unfold_advanced_shuffle_layer->getOutput(0);
} else {
std::vector<nvinfer1::ITensor*> concat_tensors;
concat_tensors.push_back(cum_adv_index_shape_tensor);
for (int i = 0; i < rank; i++) {
if (std::find(adv_idx_indices.begin(), adv_idx_indices.end(), i) == adv_idx_indices.end()) {
nvinfer1::ITensor* current_dim = dim_tensor_list[i];
concat_tensors.push_back(current_dim);
}
}
auto concat_layer = ctx->net->addConcatenation(concat_tensors.data(), concat_tensors.size());
auto shuffle_layer = ctx->net->addShuffle(*gather_out);
shuffle_layer->setInput(1, *concat_layer->getOutput(0));
reshape_output = shuffle_layer->getOutput(0);
}
}

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], reshape_output);
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
}
return true;
}})
.pattern(
Expand Down
7 changes: 6 additions & 1 deletion core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ auto prim_registrations =
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
list.emplace_back(std::move(ival));
} else {
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
if (args.at(in).IValue()->isNone()) {
auto ival = torch::jit::IValue();
list.emplace_back(std::move(ival));
} else {
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
}
}
}
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
Expand Down
Loading