Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numpy changes for aten::index converter #2396

Merged
merged 6 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def cast_int_int_div_trt_tensor(


def broadcastable(
a: TRTTensor,
b: TRTTensor,
a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
) -> bool:
"Check if two tensors are broadcastable according to torch rules"
a_shape = tuple(a.shape)
Expand Down
103 changes: 64 additions & 39 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
Expand Down Expand Up @@ -80,23 +81,34 @@ def index(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
index: Union[TRTTensor, Sequence[TRTTensor]],
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# _LOGGER.debug(f"The index shape is {index.shape}")
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)

# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
# If any is not this flag will be set to False
_LOGGER.debug(
f"Determining whether aten.index constant-index optimization can be invoked"
)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
)
# here we need to check if all the index are broadcastable
# if no, then we need to broadcast
last_index = None
for i, ind in enumerate(index):
if ind is not None:
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")
adv_indx_indices.append(i)
# torch.nn.parameter.Parameter=> torch.Tensor
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
# torch.nn.parameter.Parameter=> numpy array
# numpy array is kept as numpy
# other cases are kept as TRTTensor
if is_numpy:
ind = to_numpy(ind)
else:
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
assert broadcastable(
ind, last_index
Expand All @@ -110,8 +122,9 @@ def index(
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
return identity_layer.get_output(0)
elif len(tensor_indices) == 1:
# This case works
indices_tensor = tensor_indices[0]
indices_tensor = get_trt_tensor(
ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
)
index = adv_indx_indices[0]
_LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
gather_layer = ctx.net.add_gather(input, indices_tensor, index)
Expand Down Expand Up @@ -150,6 +163,7 @@ def index(
if i not in adv_indx_indices:
new_order.append(i)
_LOGGER.debug(f"The new transpose order is {new_order}")

transpose_layer.second_transpose = tuple(new_order)
set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
transpose_tensor = transpose_layer.get_output(0)
Expand All @@ -175,47 +189,58 @@ def index(
concat_tensor = concat_tensor_layer.get_output(0)

reshape_layer = ctx.net.add_shuffle(transpose_tensor)
# check this
reshape_layer.set_input(1, concat_tensor)
flatten_tensor = reshape_layer.get_output(0)

_LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}")

# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
# // j dimension of input x.
multiplier = get_trt_tensor(
ctx,
dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
name + "_dim_last",
)
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_intermediate_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
tensor_indices[i],
if is_numpy:
multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = multiplier * tensor_indices[i]
cum_adv_index = cum_adv_index + adv_index
multiplier = multiplier * input_shape[adv_indx_indices[i]]
cum_adv_index = get_trt_tensor(
ctx, cum_adv_index, name + f"_index_sum_intermediate"
)
cum_adv_index = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_sum_intermediate_{i}",
trt.ElementWiseOperation.SUM,
cum_adv_index,
adv_index,
)
multiplier = convert_binary_elementwise(
gs-olive marked this conversation as resolved.
Show resolved Hide resolved
else:
multiplier = get_trt_tensor(
ctx,
target,
source_ir,
name + f"_index_intermediate_xj_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
dim_tensor_list[adv_indx_indices[i]],
dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
name + "_dim_last",
)
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_intermediate_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
tensor_indices[i],
)
cum_adv_index = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_sum_intermediate_{i}",
trt.ElementWiseOperation.SUM,
cum_adv_index,
adv_index,
)
multiplier = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_index_intermediate_xj_{i}",
trt.ElementWiseOperation.PROD,
multiplier,
dim_tensor_list[adv_indx_indices[i]],
)

gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0)
set_layer_name(
Expand Down
30 changes: 28 additions & 2 deletions tests/py/dynamo/conversion/test_index_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import torch
import torch.nn as nn
from .harness import DispatchTestCase
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestIndexConverter(DispatchTestCase):
def test_index_zero_two_dim(self):
Expand All @@ -27,6 +26,21 @@ def forward(self, x):
input,
)

def test_index_zero_two_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = torch.randn(2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
self.run_test(
TestModule(),
[input, index0],
)

def test_index_zero_index_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
Expand All @@ -44,6 +58,18 @@ def forward(self, x):
input,
)

def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = torch.randn(2, 2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
self.run_test(TestModule(), [input, index0])

def test_index_zero_index_one_index_two_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
Expand Down
Loading