Skip to content

Commit

Permalink
Addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Oct 19, 2023
1 parent 0fc9c75 commit 3e4363b
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,21 @@ def index(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
index: Union[
TRTTensor,
Sequence[TRTTensor],
np.ndarray,
Sequence[np.ndarray],
torch.Tensor,
Sequence[torch.Tensor],
],
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# _LOGGER.debug(f"The index shape is {index.shape}")
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
# If any is not this flag will be set to False
is_numpy = True
_LOGGER.debug(f"Checking for the is_numpy flag")
for i, ind in enumerate(index):
if ind is None:
continue
if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)):
is_numpy = False
break
_LOGGER.debug(
f"Determining whether aten.index constant-index optimization can be invoked"
)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
)
# here we need to check if all the index are broadcastable
# if no, then we need to broadcast
last_index = None
Expand All @@ -117,7 +108,6 @@ def index(
# other cases are kept as TRTTensor
if is_numpy:
ind = to_numpy(ind)
is_numpy = True
else:
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
if last_index is not None:
Expand Down Expand Up @@ -156,9 +146,7 @@ def index(
for i in range(rank):
dim = input_shape[i]
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
# dim_tensor_list is a list of tensors or numpy
if is_numpy:
dim_list.append(dim)
# dim_tensor_list is a list of tensors
dim_tensor_list.append(dim_tensor)

# for cases like
Expand Down Expand Up @@ -211,12 +199,12 @@ def index(
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
# // j dimension of input x.
if is_numpy:
multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]]
multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
cum_adv_index = tensor_indices[adv_indx_count - 1]
for i in range(adv_indx_count - 2, -1, -1):
adv_index = multiplier * tensor_indices[i]
cum_adv_index = cum_adv_index + adv_index
multiplier = multiplier * dim_list[adv_indx_indices[i]]
multiplier = multiplier * input_shape[adv_indx_indices[i]]
cum_adv_index = get_trt_tensor(
ctx, cum_adv_index, name + f"_index_sum_intermediate"
)
Expand Down

0 comments on commit 3e4363b

Please sign in to comment.