From b1534f8da90ac02f44a031bcd1d007a2679b8347 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 13 Oct 2023 09:24:23 -0700 Subject: [PATCH 1/6] Numpy changes for index --- .../dynamo/conversion/converter_utils.py | 3 +- .../dynamo/conversion/impl/select.py | 275 +++++++++++------- 2 files changed, 163 insertions(+), 115 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index b382c5c329..06d429df4a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -180,8 +180,7 @@ def cast_int_int_div_trt_tensor( def broadcastable( - a: TRTTensor, - b: TRTTensor, + a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray] ) -> bool: "Check if two tensors are broadcastable according to torch rules" a_shape = tuple(a.shape) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 70f94cdca8..e09fc16c4e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -3,6 +3,7 @@ import numpy as np import tensorrt as trt +import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -87,6 +88,8 @@ def index( # _LOGGER.debug(f"The index shape is {index.shape}") # check if the input is dynamic dynamic_shape = has_dynamic_shape(input.shape) + #is_numpy is a flag to specify if input isa numpy + is_numpy = False # here we need to check if all the index are broadcastable # if no, then we need to broadcast @@ -95,8 +98,14 @@ def index( if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") adv_indx_indices.append(i) - # torch.nn.parameter.Parameter=> torch.Tensor - ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") + # torch.nn.parameter.Parameter=> numpy array + # numpy array is kept as numpy + # other cases are kept as TRTTensor + if (isinstance(ind, torch.Tensor) or (ind, np.ndarray)): + ind = to_numpy(ind) + is_numpy = True + else: + ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") if last_index is not None: assert broadcastable( ind, last_index @@ -131,8 +140,11 @@ def index( for i in range(rank): dim = input_shape[i] dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}") - # dim_tensor_list is a list of tensors - dim_tensor_list.append(dim_tensor) + # dim_tensor_list is a list of tensors or numpy + if(is_numpy): + dim_tensor_list.append(dim) + else: + dim_tensor_list.append(dim_tensor) # for cases like # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n], @@ -150,9 +162,14 @@ def index( if i not in adv_indx_indices: new_order.append(i) _LOGGER.debug(f"The new transpose order is {new_order}") - transpose_layer.second_transpose = tuple(new_order) - set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir) - transpose_tensor = transpose_layer.get_output(0) + + transpose_tensor = None + if(is_numpy): + transpose_tensor = input[new_order] + else: + transpose_layer.second_transpose = tuple(new_order) + set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir) + transpose_tensor = transpose_layer.get_output(0) # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n] # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor) @@ -165,57 +182,70 @@ def index( 
for i in range(adv_indx_count, rank): mult_d1 = mult_d1 * transpose_tensor_shape[i] - concat_tensor_layer = ctx.net.add_concatenation( - [ - get_trt_tensor(ctx, mult_d0, name + "_d0_shape"), - get_trt_tensor(ctx, mult_d1, name + "_d1_shape"), - ] - ) - set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir) - concat_tensor = concat_tensor_layer.get_output(0) + flatten_tensor = None + if(is_numpy): + flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1) + else: + concat_tensor_layer = ctx.net.add_concatenation( + [ + get_trt_tensor(ctx, mult_d0, name + "_d0_shape"), + get_trt_tensor(ctx, mult_d1, name + "_d1_shape"), + ] + ) + set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir) + concat_tensor = concat_tensor_layer.get_output(0) + + reshape_layer = ctx.net.add_shuffle(transpose_tensor) + reshape_layer.set_input(1, concat_tensor) + flatten_tensor = reshape_layer.get_output(0) - reshape_layer = ctx.net.add_shuffle(transpose_tensor) - # check this - reshape_layer.set_input(1, concat_tensor) - flatten_tensor = reshape_layer.get_output(0) _LOGGER.debug(f"The flatten tensor shape is {flatten_tensor.shape}") # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the # // j dimension of input x. - multiplier = get_trt_tensor( - ctx, - dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], - name + "_dim_last", - ) - cum_adv_index = tensor_indices[adv_indx_count - 1] - for i in range(adv_indx_count - 2, -1, -1): - adv_index = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_index_intermediate_{i}", - trt.ElementWiseOperation.PROD, - multiplier, - tensor_indices[i], - ) - cum_adv_index = convert_binary_elementwise( - ctx, - target, - source_ir, - name + f"_index_sum_intermediate_{i}", - trt.ElementWiseOperation.SUM, - cum_adv_index, - adv_index, - ) - multiplier = convert_binary_elementwise( + if(is_numpy): + multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]] + cum_adv_index = tensor_indices[adv_indx_count - 1] + for i in range(adv_indx_count - 2, -1, -1): + adv_index = multiplier * tensor_indices[i] + cum_adv_index = cum_adv_index + adv_index + multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]] + else: + + multiplier = get_trt_tensor( ctx, - target, - source_ir, - name + f"_index_intermediate_xj_{i}", - trt.ElementWiseOperation.PROD, - multiplier, - dim_tensor_list[adv_indx_indices[i]], + dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], + name + "_dim_last", ) + cum_adv_index = tensor_indices[adv_indx_count - 1] + for i in range(adv_indx_count - 2, -1, -1): + adv_index = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_index_intermediate_{i}", + trt.ElementWiseOperation.PROD, + multiplier, + tensor_indices[i], + ) + cum_adv_index = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_index_sum_intermediate_{i}", + trt.ElementWiseOperation.SUM, + cum_adv_index, + adv_index, + ) + multiplier = convert_binary_elementwise( + ctx, + target, + source_ir, + name + f"_index_intermediate_xj_{i}", + trt.ElementWiseOperation.PROD, + multiplier, + dim_tensor_list[adv_indx_indices[i]], + ) gather_layer_element = ctx.net.add_gather(flatten_tensor, cum_adv_index, 0) set_layer_name( @@ -239,29 +269,36 @@ def index( == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1 ): _LOGGER.debug(f"The indices are continuous in this case") - concat_tensor_reshape.append( - get_trt_tensor(ctx, -1, name + "_dynamic_concat") - ) + 
if(is_numpy): + concat_tensor_reshape.append(-1) + else: + concat_tensor_reshape.append( + get_trt_tensor(ctx, -1, name + "_dynamic_concat") + ) for i in range(0, rank): if i not in adv_indx_indices: curr_dim = dim_tensor_list[i] concat_tensor_reshape.append(curr_dim) - concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape) - set_layer_name( - concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir - ) - concat_tensor = concat_tensor_layer.get_output(0) + unfold_tensor = None + if(is_numpy): + unfold_tensor = gather_out.reshape(concat_tensor) + else: + concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape) + set_layer_name( + concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir + ) + concat_tensor = concat_tensor_layer.get_output(0) - regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out) - regular_index_shuffle_layer.set_input(1, concat_tensor) - set_layer_name( - regular_index_shuffle_layer, - target, - name + "_index_regular_index", - source_ir, - ) - unfold_tensor = regular_index_shuffle_layer.get_output(0) + regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out) + regular_index_shuffle_layer.set_input(1, concat_tensor) + set_layer_name( + regular_index_shuffle_layer, + target, + name + "_index_regular_index", + source_ir, + ) + unfold_tensor = regular_index_shuffle_layer.get_output(0) _LOGGER.debug(f"The tensor is unfolded now") _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}") @@ -275,14 +312,18 @@ def index( new_order.append(i) _LOGGER.debug(f"Transposing the indices to correct position {new_order}") - transpose_advanced_shuffle_layer.second_transpose = tuple(new_order) - set_layer_name( - transpose_advanced_shuffle_layer, - target, - name + "_index_advanced_shuffle_transpose", - source_ir, - ) - transpose_tensor = transpose_advanced_shuffle_layer.get_output(0) + transpose_tensor = None + if(is_numpy): + transpose_tensor = unfold_tensor[new_order] + else: + transpose_advanced_shuffle_layer.second_transpose = tuple(new_order) + set_layer_name( + transpose_advanced_shuffle_layer, + target, + name + "_index_advanced_shuffle_transpose", + source_ir, + ) + transpose_tensor = transpose_advanced_shuffle_layer.get_output(0) # unfold advanced layer concat_final_tensor = [] @@ -296,25 +337,29 @@ def index( current_dim = dim_tensor_list[i] concat_final_tensor.append(current_dim) - concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) - set_layer_name( - concat_final_shape_layer, - target, - name + "_index_continuous_concat_final_shape_layer", - source_ir, - ) - concat_final_tensor = concat_final_shape_layer.get_output(0) - - unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor) - # check this - unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor) - set_layer_name( - unfold_advanced_shuffle_layer, - target, - name + "_unfold_advanced_index", - source_ir, - ) - reshape_output = unfold_advanced_shuffle_layer.get_output(0) + reshape_output = [] + if(is_numpy): + reshape_output = transpose_tensor.reshape(concat_final_tensor) + else: + concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) + set_layer_name( + concat_final_shape_layer, + target, + name + "_index_continuous_concat_final_shape_layer", + source_ir, + ) + concat_final_tensor = concat_final_shape_layer.get_output(0) + + unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor) + # check this + unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor) + 
set_layer_name( + unfold_advanced_shuffle_layer, + target, + name + "_unfold_advanced_index", + source_ir, + ) + reshape_output = unfold_advanced_shuffle_layer.get_output(0) else: _LOGGER.debug(f"The indices are not continuous in this case") @@ -325,23 +370,27 @@ def index( curr_dim = dim_tensor_list[i] concat_final_tensor.append(curr_dim) - concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) - set_layer_name( - concat_final_shape_layer, - target, - name + "_index_non_continuous_concat_final_shape_layer", - source_ir, - ) - concat_final_tensor = concat_final_shape_layer.get_output(0) - - reshape_layer = ctx.net.add_shuffle(gather_out) - reshape_layer.set_input(1, concat_final_tensor) - set_layer_name( - reshape_layer, - target, - name + "_index_non_continuous_shuffle_final_shape_layer", - source_ir, - ) - reshape_output = reshape_layer.get_output(0) + reshape_output = None + if(is_numpy): + reshape_output = gather_out.reshape(concat_final_tensor) + else: + concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) + set_layer_name( + concat_final_shape_layer, + target, + name + "_index_non_continuous_concat_final_shape_layer", + source_ir, + ) + concat_final_tensor = concat_final_shape_layer.get_output(0) + + reshape_layer = ctx.net.add_shuffle(gather_out) + reshape_layer.set_input(1, concat_final_tensor) + set_layer_name( + reshape_layer, + target, + name + "_index_non_continuous_shuffle_final_shape_layer", + source_ir, + ) + reshape_output = reshape_layer.get_output(0) return reshape_output From 0fc9c752e1dfd530d1ae0e13a1d6354118629a54 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 16 Oct 2023 12:23:22 -0700 Subject: [PATCH 2/6] Correction for the is_numpy cases for mix of numpy and non numpy inputs --- .../dynamo/conversion/impl/select.py | 226 +++++++++--------- 1 file changed, 108 insertions(+), 118 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index e09fc16c4e..5bf7bbf0e6 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -81,16 +81,30 @@ def index( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - index: Union[TRTTensor, Sequence[TRTTensor]], + index: Union[ + TRTTensor, + Sequence[TRTTensor], + np.ndarray, + Sequence[np.ndarray], + torch.Tensor, + Sequence[torch.Tensor], + ], ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] # _LOGGER.debug(f"The index shape is {index.shape}") # check if the input is dynamic dynamic_shape = has_dynamic_shape(input.shape) - #is_numpy is a flag to specify if input isa numpy - is_numpy = False - + # is_numpy is a flag to specify if all the indices are numpy or torchTensor. 
+ # If any is not this flag will be set to False + is_numpy = True + _LOGGER.debug(f"Checking for the is_numpy flag") + for i, ind in enumerate(index): + if ind is None: + continue + if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)): + is_numpy = False + break # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None @@ -101,7 +115,7 @@ def index( # torch.nn.parameter.Parameter=> numpy array # numpy array is kept as numpy # other cases are kept as TRTTensor - if (isinstance(ind, torch.Tensor) or (ind, np.ndarray)): + if is_numpy: ind = to_numpy(ind) is_numpy = True else: @@ -119,8 +133,9 @@ def index( set_layer_name(identity_layer, target, name + "_index_identity", source_ir) return identity_layer.get_output(0) elif len(tensor_indices) == 1: - # This case works - indices_tensor = tensor_indices[0] + indices_tensor = get_trt_tensor( + ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor" + ) index = adv_indx_indices[0] _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") gather_layer = ctx.net.add_gather(input, indices_tensor, index) @@ -136,15 +151,15 @@ def index( rank = len(input_shape) adv_indx_count = len(adv_indx_indices) dim_tensor_list = [] + dim_list = [] for i in range(rank): dim = input_shape[i] dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}") # dim_tensor_list is a list of tensors or numpy - if(is_numpy): - dim_tensor_list.append(dim) - else: - dim_tensor_list.append(dim_tensor) + if is_numpy: + dim_list.append(dim) + dim_tensor_list.append(dim_tensor) # for cases like # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n], @@ -163,13 +178,9 @@ def index( new_order.append(i) _LOGGER.debug(f"The new transpose order is {new_order}") - transpose_tensor = None - if(is_numpy): - transpose_tensor = input[new_order] - else: - transpose_layer.second_transpose = tuple(new_order) - set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir) - transpose_tensor = transpose_layer.get_output(0) + transpose_layer.second_transpose = tuple(new_order) + set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir) + transpose_tensor = transpose_layer.get_output(0) # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_n] # transpose_tensor_shape = ctx.net.add_shape(transpose_tensor) @@ -182,36 +193,34 @@ def index( for i in range(adv_indx_count, rank): mult_d1 = mult_d1 * transpose_tensor_shape[i] - flatten_tensor = None - if(is_numpy): - flatten_tensor = transpose_tensor.reshape(mult_d0, mult_d1) - else: - concat_tensor_layer = ctx.net.add_concatenation( - [ - get_trt_tensor(ctx, mult_d0, name + "_d0_shape"), - get_trt_tensor(ctx, mult_d1, name + "_d1_shape"), - ] - ) - set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir) - concat_tensor = concat_tensor_layer.get_output(0) + concat_tensor_layer = ctx.net.add_concatenation( + [ + get_trt_tensor(ctx, mult_d0, name + "_d0_shape"), + get_trt_tensor(ctx, mult_d1, name + "_d1_shape"), + ] + ) + set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir) + concat_tensor = concat_tensor_layer.get_output(0) - reshape_layer = ctx.net.add_shuffle(transpose_tensor) - reshape_layer.set_input(1, concat_tensor) - flatten_tensor = reshape_layer.get_output(0) + reshape_layer = ctx.net.add_shuffle(transpose_tensor) + reshape_layer.set_input(1, concat_tensor) + flatten_tensor = reshape_layer.get_output(0) _LOGGER.debug(f"The flatten tensor shape is 
{flatten_tensor.shape}") # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the # // j dimension of input x. - if(is_numpy): - multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]] + if is_numpy: + multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]] cum_adv_index = tensor_indices[adv_indx_count - 1] for i in range(adv_indx_count - 2, -1, -1): adv_index = multiplier * tensor_indices[i] cum_adv_index = cum_adv_index + adv_index - multiplier = multiplier * dim_tensor_list[adv_indx_indices[i]] + multiplier = multiplier * dim_list[adv_indx_indices[i]] + cum_adv_index = get_trt_tensor( + ctx, cum_adv_index, name + f"_index_sum_intermediate" + ) else: - multiplier = get_trt_tensor( ctx, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], @@ -269,36 +278,29 @@ def index( == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1 ): _LOGGER.debug(f"The indices are continuous in this case") - if(is_numpy): - concat_tensor_reshape.append(-1) - else: - concat_tensor_reshape.append( - get_trt_tensor(ctx, -1, name + "_dynamic_concat") - ) + concat_tensor_reshape.append( + get_trt_tensor(ctx, -1, name + "_dynamic_concat") + ) for i in range(0, rank): if i not in adv_indx_indices: curr_dim = dim_tensor_list[i] concat_tensor_reshape.append(curr_dim) - unfold_tensor = None - if(is_numpy): - unfold_tensor = gather_out.reshape(concat_tensor) - else: - concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape) - set_layer_name( - concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir - ) - concat_tensor = concat_tensor_layer.get_output(0) + concat_tensor_layer = ctx.net.add_concatenation(concat_tensor_reshape) + set_layer_name( + concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir + ) + concat_tensor = concat_tensor_layer.get_output(0) - regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out) - regular_index_shuffle_layer.set_input(1, concat_tensor) - set_layer_name( - regular_index_shuffle_layer, - target, - name + "_index_regular_index", - source_ir, - ) - unfold_tensor = regular_index_shuffle_layer.get_output(0) + regular_index_shuffle_layer = ctx.net.add_shuffle(gather_out) + regular_index_shuffle_layer.set_input(1, concat_tensor) + set_layer_name( + regular_index_shuffle_layer, + target, + name + "_index_regular_index", + source_ir, + ) + unfold_tensor = regular_index_shuffle_layer.get_output(0) _LOGGER.debug(f"The tensor is unfolded now") _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}") @@ -312,18 +314,14 @@ def index( new_order.append(i) _LOGGER.debug(f"Transposing the indices to correct position {new_order}") - transpose_tensor = None - if(is_numpy): - transpose_tensor = unfold_tensor[new_order] - else: - transpose_advanced_shuffle_layer.second_transpose = tuple(new_order) - set_layer_name( - transpose_advanced_shuffle_layer, - target, - name + "_index_advanced_shuffle_transpose", - source_ir, - ) - transpose_tensor = transpose_advanced_shuffle_layer.get_output(0) + transpose_advanced_shuffle_layer.second_transpose = tuple(new_order) + set_layer_name( + transpose_advanced_shuffle_layer, + target, + name + "_index_advanced_shuffle_transpose", + source_ir, + ) + transpose_tensor = transpose_advanced_shuffle_layer.get_output(0) # unfold advanced layer concat_final_tensor = [] @@ -337,29 +335,25 @@ def index( current_dim = dim_tensor_list[i] concat_final_tensor.append(current_dim) - reshape_output = [] - if(is_numpy): - reshape_output = 
transpose_tensor.reshape(concat_final_tensor) - else: - concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) - set_layer_name( - concat_final_shape_layer, - target, - name + "_index_continuous_concat_final_shape_layer", - source_ir, - ) - concat_final_tensor = concat_final_shape_layer.get_output(0) - - unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor) - # check this - unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor) - set_layer_name( - unfold_advanced_shuffle_layer, - target, - name + "_unfold_advanced_index", - source_ir, - ) - reshape_output = unfold_advanced_shuffle_layer.get_output(0) + concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) + set_layer_name( + concat_final_shape_layer, + target, + name + "_index_continuous_concat_final_shape_layer", + source_ir, + ) + concat_final_tensor = concat_final_shape_layer.get_output(0) + + unfold_advanced_shuffle_layer = ctx.net.add_shuffle(transpose_tensor) + # check this + unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor) + set_layer_name( + unfold_advanced_shuffle_layer, + target, + name + "_unfold_advanced_index", + source_ir, + ) + reshape_output = unfold_advanced_shuffle_layer.get_output(0) else: _LOGGER.debug(f"The indices are not continuous in this case") @@ -370,27 +364,23 @@ def index( curr_dim = dim_tensor_list[i] concat_final_tensor.append(curr_dim) - reshape_output = None - if(is_numpy): - reshape_output = gather_out.reshape(concat_final_tensor) - else: - concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) - set_layer_name( - concat_final_shape_layer, - target, - name + "_index_non_continuous_concat_final_shape_layer", - source_ir, - ) - concat_final_tensor = concat_final_shape_layer.get_output(0) - - reshape_layer = ctx.net.add_shuffle(gather_out) - reshape_layer.set_input(1, concat_final_tensor) - set_layer_name( - reshape_layer, - target, - name + "_index_non_continuous_shuffle_final_shape_layer", - source_ir, - ) - reshape_output = reshape_layer.get_output(0) + concat_final_shape_layer = ctx.net.add_concatenation(concat_final_tensor) + set_layer_name( + concat_final_shape_layer, + target, + name + "_index_non_continuous_concat_final_shape_layer", + source_ir, + ) + concat_final_tensor = concat_final_shape_layer.get_output(0) + + reshape_layer = ctx.net.add_shuffle(gather_out) + reshape_layer.set_input(1, concat_final_tensor) + set_layer_name( + reshape_layer, + target, + name + "_index_non_continuous_shuffle_final_shape_layer", + source_ir, + ) + reshape_output = reshape_layer.get_output(0) return reshape_output From 97e3758ca3dd4ed30ed9137b253604d829d37fd6 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 19 Oct 2023 15:18:57 -0700 Subject: [PATCH 3/6] Addressing review comments --- .../dynamo/conversion/impl/select.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 5bf7bbf0e6..b2e1b86a62 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -81,30 +81,20 @@ def index( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - index: Union[ - TRTTensor, - Sequence[TRTTensor], - np.ndarray, - Sequence[np.ndarray], - torch.Tensor, - Sequence[torch.Tensor], - ], + index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] - # 
_LOGGER.debug(f"The index shape is {index.shape}") # check if the input is dynamic dynamic_shape = has_dynamic_shape(input.shape) # is_numpy is a flag to specify if all the indices are numpy or torchTensor. # If any is not this flag will be set to False - is_numpy = True - _LOGGER.debug(f"Checking for the is_numpy flag") - for i, ind in enumerate(index): - if ind is None: - continue - if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)): - is_numpy = False - break + _LOGGER.debug( + f"Determining whether aten.index constant-index optimization can be invoked" + ) + is_numpy = all( + isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None + ) # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None @@ -117,7 +107,6 @@ def index( # other cases are kept as TRTTensor if is_numpy: ind = to_numpy(ind) - is_numpy = True else: ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}") if last_index is not None: @@ -156,9 +145,7 @@ def index( for i in range(rank): dim = input_shape[i] dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}") - # dim_tensor_list is a list of tensors or numpy - if is_numpy: - dim_list.append(dim) + # dim_tensor_list is a list of tensors dim_tensor_list.append(dim_tensor) # for cases like @@ -211,12 +198,12 @@ def index( # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the # // j dimension of input x. if is_numpy: - multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]] + multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]] cum_adv_index = tensor_indices[adv_indx_count - 1] for i in range(adv_indx_count - 2, -1, -1): adv_index = multiplier * tensor_indices[i] cum_adv_index = cum_adv_index + adv_index - multiplier = multiplier * dim_list[adv_indx_indices[i]] + multiplier = multiplier * input_shape[adv_indx_indices[i]] cum_adv_index = get_trt_tensor( ctx, cum_adv_index, name + f"_index_sum_intermediate" ) From 563ca81d0b8ce61955a2b37b5252ba8b9c394dc9 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 20 Oct 2023 15:43:50 -0700 Subject: [PATCH 4/6] removing dim_list --- py/torch_tensorrt/dynamo/conversion/impl/select.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index b2e1b86a62..db586be65f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -140,7 +140,6 @@ def index( rank = len(input_shape) adv_indx_count = len(adv_indx_indices) dim_tensor_list = [] - dim_list = [] for i in range(rank): dim = input_shape[i] From 81f40a32cb928a5b6680485e0a178a6be33c57e5 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 30 Oct 2023 12:36:18 -0700 Subject: [PATCH 5/6] Index ITensor test --- tests/py/dynamo/conversion/test_index_aten.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 393eb53c63..24a201aebf 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -2,11 +2,10 @@ import torch import torch.nn as nn +from harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from .harness import DispatchTestCase - class TestIndexConverter(DispatchTestCase): def 
test_index_zero_two_dim(self): @@ -27,6 +26,21 @@ def forward(self, x): input, ) + def test_index_zero_two_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test( + TestModule(), + [input, index0], + ) + def test_index_zero_index_three_dim(self): class TestModule(nn.Module): def __init__(self): @@ -44,6 +58,18 @@ def forward(self, x): input, ) + def test_index_zero_index_three_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0, None] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test(TestModule(), [input, index0]) + def test_index_zero_index_one_index_two_three_dim(self): class TestModule(nn.Module): def __init__(self): From 6741707557e614ee3ce19fd9175b15379e4c8fd3 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 6 Nov 2023 08:31:34 -0800 Subject: [PATCH 6/6] Correcting the .harness in test_index_aten.py --- tests/py/dynamo/conversion/test_index_aten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 24a201aebf..df61a4b835 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from harness import DispatchTestCase +from .harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input
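
A minimal, numpy-only sketch of the flattened-gather formulation that the is_numpy fast path in these patches precomputes: the cumulative advanced index cum_adv_index = ind_m + x_m*ind_{m-1} + x_m*x_{m-1}*ind_{m-2} + ... is built on the host, and a single gather is then performed on the input flattened over the advanced dimensions. The helper name and the example inputs below are illustrative assumptions, not code from the patch series.

# Illustrative sketch (not part of the patches): reproduces, in plain numpy,
# the flattened-gather index arithmetic used when all indices are constant.
import numpy as np

def gather_via_flattened_index(x, indices):
    """Emulate x[idx0, idx1, ...] (advanced indexing on the leading dims)
    by flattening those dims and gathering with one cumulative index."""
    adv_count = len(indices)
    adv_dims = x.shape[:adv_count]
    rest_dims = x.shape[adv_count:]

    # cum_adv_index = sum_i ind_i * prod_{j>i} dim_j  (row-major flattening)
    multiplier = adv_dims[-1]
    cum_adv_index = np.asarray(indices[-1])
    for i in range(adv_count - 2, -1, -1):
        cum_adv_index = cum_adv_index + multiplier * np.asarray(indices[i])
        multiplier = multiplier * adv_dims[i]

    # Flatten the advanced dims into one axis, gather, then restore the rest.
    flat = x.reshape(int(np.prod(adv_dims)), -1)
    gathered = flat[cum_adv_index]
    return gathered.reshape(cum_adv_index.shape + rest_dims)

if __name__ == "__main__":
    x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    idx0 = np.array([0, 1, 1])
    idx1 = np.array([2, 0, 1])
    assert np.array_equal(gather_via_flattened_index(x, [idx0, idx1]), x[idx0, idx1])
    print("flattened-gather matches direct advanced indexing")

Precomputing cum_adv_index on the host this way avoids emitting the chain of TensorRT elementwise PROD/SUM layers that the non-constant path still builds.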