From d72bb52ecb826408c1bd0e597c21fb54f024a8e5 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 18 Jul 2024 20:20:40 -0700 Subject: [PATCH 1/8] add dynamic shape support for roll/scaler_tensor --- .../dynamo/conversion/aten_ops_converters.py | 7 +- .../dynamo/conversion/impl/permutation.py | 391 ++++++++++++++++-- tests/py/dynamo/conversion/test_roll_aten.py | 33 +- 3 files changed, 392 insertions(+), 39 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 3da1b09fba..514197b84d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2726,6 +2726,7 @@ def attention_validator(node: Node) -> bool: @dynamo_tensorrt_converter( torch.nn.functional.scaled_dot_product_attention, capability_validator=attention_validator, + supports_dynamic_shapes=True, ) def tensorrt_scaled_dot_product_attention( ctx: ConversionContext, @@ -3333,7 +3334,9 @@ def aten_ops_diagonal( ) -@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default) +@dynamo_tensorrt_converter( + torch.ops.aten.scalar_tensor.default, supports_dynamic_shapes=True +) def aten_ops_scalar_tensor( ctx: ConversionContext, target: Target, @@ -3346,7 +3349,7 @@ def aten_ops_scalar_tensor( ) -@dynamo_tensorrt_converter(torch.ops.aten.roll.default) +@dynamo_tensorrt_converter(torch.ops.aten.roll.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 4fabebd176..9b7a8644bb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Union import tensorrt as trt from torch.fx.node import Target @@ -6,10 +6,12 @@ from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( - flatten_dims, get_positive_dim, + get_trt_tensor, + has_dynamic_shape, + set_layer_name, ) -from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.fx.types import TRTTensor @@ -34,6 +36,278 @@ def permute( return layer.get_output(0) +# def roll( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input: TRTTensor, +# shifts: Union[int, Sequence[int]], +# dims: Union[int, Sequence[int]], +# ) -> TRTTensor: + +# if isinstance(shifts, int): +# shifts = [shifts] +# if isinstance(dims, int): +# dims = [dims] + +# if dims != []: +# # preprocess dims, in case that dims has multiple same dim +# # for example shifts:[1, 2, 1], dims: [1, 0, 1] +# # can be simplified to shifts: [2, 2], dims: [1, 0] +# shift_dict = {} +# for shift, dim in zip(shifts, dims): +# if dim in shift_dict: +# shift_dict[dim] += shift +# else: +# shift_dict[dim] = shift + +# is_dynamic_shape = has_dynamic_shape(input.shape) + +# # handle static shape for the input tensor: +# if not is_dynamic_shape: +# orignal_shape = input.shape +# if dims != []: +# # calculate start, stride when dims is not empty +# start = [0] * len(input.shape) +# stride = [1] * len(input.shape) +# for d, s in 
shift_dict.items(): +# start[d] = get_positive_dim(-s, input.shape[d]) +# else: +# # flatten input tensor +# input = impl.shuffle.reshape(ctx, target, source_ir, name+"_reshape", input, (1, -1)) +# # calculate start, stride when dims are empty +# print(f"lan added {orignal_shape=} {input.shape=}") +# start = [get_positive_dim(-shifts[0], input.shape[1])] * len(input.shape) +# stride = [1] * len(input.shape) +# print(f"lan added {start=} {stride=}") +# slice_layer = ctx.net.add_slice( +# input, +# start=start, +# shape=input.shape, +# stride=stride, +# ) +# slice_layer.mode = trt.SampleMode.WRAP +# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) +# output = slice_layer.get_output(0) +# if dims == []: +# # reshape back +# output = impl.shuffle.reshape(ctx, target, source_ir, name+"_reshape_back", output, orignal_shape) +# return output + +# # handle dynammic shape for the input tensor +# if dims != []: +# # calculate the start and stride +# rank = len(input.shape) +# print(f"lan added {shifts=}, {dims=}, {shift_dict=}") +# start = [] +# default_tensor = get_trt_tensor(ctx, 0, name+"_get_0") +# for i in range(rank): +# start.append(default_tensor) +# stride = [1] * rank +# for d, s in shift_dict.items(): +# if s < 0: +# start[d] = get_trt_tensor(ctx, -s, name+"_ge_{d}_{s}") +# else: +# dim_length = impl.shape.shape(ctx, target, source_ir, name+"_shape", input, d) +# start[d] = impl.elementwise.sub(ctx, target, source_ir, name+"_sub", dim_length, s) +# concat_layer = ctx.net.add_concatenation(start) +# concat_layer.axis = 0 +# set_layer_name(concat_layer, target, f"{name}_gather", source_ir) +# start = concat_layer.get_output(0) +# print(f"lan added {start=} {stride=}") +# # rolling the tensor by start and stride +# slice_layer = ctx.net.add_slice( +# input, +# start=[], +# shape=[], +# stride=stride, +# ) +# slice_layer.set_input(1, start) +# slice_layer.set_input(2, get_shape_with_dynamic_shape(ctx, target, source_ir, name+"_shape", input.shape, input)) +# slice_layer.mode = trt.SampleMode.WRAP +# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) +# return slice_layer.get_output(0) + +# else: +# # if dims is None, the tensor will be flattened before rolling and then restored to the original shape + +# # flatten the input tensor +# flattened_output = impl.shuffle.reshape( +# ctx, target, source_ir, f"{name}_reshape", input, (1, -1) +# ) + +# # calculate the start and stride +# if shifts[0] < 0: +# start_index = get_trt_tensor(ctx, -shifts[0], name+"_get") +# else: +# flattened_length = impl.shape.shape(ctx, target, source_ir, name+"_shape", flattened_output, 1) +# start_index = impl.elementwise.sub(ctx, target, source_ir, name+"_sub", flattened_length, shifts[0]) +# start, stride = [], [] +# for i in range(len(flattened_output.shape)): +# start.append(start_index) +# stride.append(1) + +# concat_layer = ctx.net.add_concatenation(start) +# concat_layer.axis = 0 +# set_layer_name(concat_layer, target, f"{name}_gather", source_ir) +# start = concat_layer.get_output(0) + +# # rolling the flattened tensor by start and stride +# slice_layer = ctx.net.add_slice( +# flattened_output, +# start=[], +# shape=[], +# stride=stride, +# ) +# slice_layer.set_input(1, start) +# slice_layer.set_input(2, get_shape_with_dynamic_shape(ctx, target, source_ir, name+"_output_shape", flattened_output.shape, flattened_output)) +# slice_layer.mode = trt.SampleMode.WRAP +# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) +# sliced_output = 
slice_layer.get_output(0) + +# # reshape back to the original shape +# shape_back = get_shape_with_dynamic_shape(ctx, target, source_ir, name+"_shape_back", input.shape, input) +# shape_layer = ctx.net.add_shuffle(sliced_output) +# shape_layer.set_input(1, shape_back) +# set_layer_name(shape_layer, target, name, source_ir) +# return shape_layer.get_output(0) + + +def calc_start_by_static_shape( + input: TRTTensor, + shift_dict: Dict[int, int], +) -> Sequence[int]: + start = [0] * len(input.shape) + for d, s in shift_dict.items(): + start[d] = get_positive_dim(-s, input.shape[d]) + return start + + +def calc_start_by_dynamic_shape( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + shift_dict: Dict[int, int], +) -> Sequence[TRTTensor]: + start = [] + default_tensor = get_trt_tensor(ctx, 0, name + "_get_0") + for i in range(len(input.shape)): + start.append(default_tensor) + for d, s in shift_dict.items(): + dim_length = impl.shape.shape(ctx, target, source_ir, name + "_shape", input, d) + start[d] = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", dim_length, s + ) + concat_layer = ctx.net.add_concatenation(start) + concat_layer.axis = 0 + set_layer_name(concat_layer, target, f"{name}_gather", source_ir) + return [concat_layer.get_output(0)] + + +# def roll( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input: TRTTensor, +# shifts: Union[int, Sequence[int]], +# dims: Union[int, Sequence[int]], +# ) -> TRTTensor: +# if isinstance(shifts, int): +# shifts = [shifts] +# if isinstance(dims, int): +# dims = [dims] + +# shift_dict = {} +# if dims == []: +# shift_dict[1] = shifts[0] +# else: +# # preprocess dims, in case that dims has multiple same dim +# # for example shifts:[1, 2, 1], dims: [1, 0, 1] +# # can be simplified to shifts: [2, 2], dims: [1, 0] +# for shift, dim in zip(shifts, dims): +# if dim in shift_dict: +# shift_dict[dim] += shift +# else: +# shift_dict[dim] = shift + +# is_dynamic_shape = has_dynamic_shape(input.shape) + +# # handle static shape for the input tensor: +# if not is_dynamic_shape: +# orignal_shape = input.shape +# if dims == []: +# # flatten input tensor +# input = impl.shuffle.reshape( +# ctx, target, source_ir, name + "_reshape", input, (1, -1) +# ) +# start = calc_start_by_static_shape( +# input, shift_dict +# ) +# stride = [1] * len(input.shape) +# slice_layer = ctx.net.add_slice( +# input, +# start=start, +# shape=input.shape, +# stride=stride, +# ) +# slice_layer.mode = trt.SampleMode.WRAP +# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) +# output = slice_layer.get_output(0) +# if dims == []: +# # reshape back +# output = impl.shuffle.reshape( +# ctx, target, source_ir, name + "_reshape_back", output, orignal_shape +# ) +# else: +# # handle dynammic shape for the input tensor +# orignal_input = input +# if dims == []: +# # flatten the input tensor +# input = impl.shuffle.reshape( +# ctx, target, source_ir, f"{name}_reshape", input, (1, -1) +# ) +# start = calc_start_by_dynamic_shape( +# ctx, target, source_ir, name + "_calc", input, shift_dict +# ) +# stride = [1] * len(input.shape) +# slice_layer = ctx.net.add_slice( +# input, +# start=[], +# shape=[], +# stride=stride, +# ) +# slice_layer.set_input(1, start) +# slice_layer.set_input( +# 2, +# get_shape_with_dynamic_shape( +# ctx, target, source_ir, name + "_dynamic_shape", input.shape, input +# ), +# ) +# slice_layer.mode = trt.SampleMode.WRAP +# 
set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) +# output = slice_layer.get_output(0) +# if dims == []: +# # reshape back to the original shape +# shape_back = get_shape_with_dynamic_shape( +# ctx, +# target, +# source_ir, +# name + "_shape_back", +# orignal_input.shape, +# orignal_input, +# ) +# shape_layer = ctx.net.add_shuffle(output) +# shape_layer.set_input(1, shape_back) +# set_layer_name(shape_layer, target, name+"_reshape_back", source_ir) +# output = shape_layer.get_output(0) + +# return output + + def roll( ctx: ConversionContext, target: Target, @@ -43,50 +317,95 @@ def roll( shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]], ) -> TRTTensor: - shape = input.shape if isinstance(shifts, int): shifts = [shifts] if isinstance(dims, int): dims = [dims] - if dims != []: - rank = len(shape) - start = [0] * rank - stride = [1] * rank - for i in range(len(dims)): - d = dims[i] - s = shifts[i] - start[d] += get_positive_dim( - -s, shape[d] - ) # in case that dims has multiple same dim - - layer = ctx.net.add_slice( + is_dynamic_shape = has_dynamic_shape(input.shape) + if all(isinstance(shift, TRTTensor) for shift in shifts): + is_dynamic_shift = False + else: + is_dynamic_shift = True + + shift_dict = {} + if dims == []: + shift_dict[1] = shifts[0] + else: + # preprocess dims, in case that dims has multiple same dim + # for example shifts:[1, 2, 1], dims: [1, 0, 1] + # can be simplified to shifts: [2, 2], dims: [1, 0] + for shift, dim in zip(shifts, dims): + if dim in shift_dict: + shift_dict[dim] += shift + else: + shift_dict[dim] = shift + + # handle static shape for the input tensor and shifts: + if not is_dynamic_shape and not is_dynamic_shift: + orignal_shape = input.shape + if dims == []: + # flatten input tensor + input = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape", input, (1, -1) + ) + start = calc_start_by_static_shape(input, shift_dict) + stride = [1] * len(input.shape) + slice_layer = ctx.net.add_slice( input, start=start, - shape=shape, + shape=input.shape, stride=stride, ) - layer.mode = trt.SampleMode.WRAP - set_layer_name(layer, target, f"{name}_slice_wrap", source_ir) - return layer.get_output(0) - + slice_layer.mode = trt.SampleMode.WRAP + set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) + output = slice_layer.get_output(0) + if dims == []: + # reshape back + output = impl.shuffle.reshape( + ctx, target, source_ir, name + "_reshape_back", output, orignal_shape + ) else: - flatten_shape = flatten_dims(input, 0, -1) - output = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape", input, flatten_shape + # handle dynammic shape for the input tensor and shifts + orignal_input = input + if dims == []: + # flatten the input tensor + input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape", input, (1, -1) + ) + start = calc_start_by_dynamic_shape( + ctx, target, source_ir, name + "_calc", input, shift_dict ) - start = [get_positive_dim(-shifts[0], output.shape[0])] - stride = [1] - layer = ctx.net.add_slice( - output, - start=start, - shape=flatten_shape, + stride = [1] * len(input.shape) + slice_layer = ctx.net.add_slice( + input, + start=[], + shape=[], stride=stride, ) - layer.mode = trt.SampleMode.WRAP - set_layer_name(layer, target, f"{name}_slice_wrap", source_ir) - output = layer.get_output(0) - output = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_reshape_back", output, shape + slice_layer.set_input(1, start) + slice_layer.set_input( + 2, + 
get_shape_with_dynamic_shape( + ctx, target, source_ir, name + "_dynamic_shape", input.shape, input + ), ) - return output + slice_layer.mode = trt.SampleMode.WRAP + set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) + output = slice_layer.get_output(0) + if dims == []: + # reshape back to the original shape + shape_back = get_shape_with_dynamic_shape( + ctx, + target, + source_ir, + name + "_shape_back", + orignal_input.shape, + orignal_input, + ) + shape_layer = ctx.net.add_shuffle(output) + shape_layer.set_input(1, shape_back) + set_layer_name(shape_layer, target, name + "_reshape_back", source_ir) + output = shape_layer.get_output(0) + + return output diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py index 80e9020855..42567349c4 100644 --- a/tests/py/dynamo/conversion/test_roll_aten.py +++ b/tests/py/dynamo/conversion/test_roll_aten.py @@ -28,9 +28,10 @@ class TestRollConverter(DispatchTestCase): ], [], ), + ((2, 3), [1], [1]), ] ) - def test_roll(self, shape, shifts, dims): + def test_roll_static(self, shape, shifts, dims): class Roll(nn.Module): def forward(self, x): return torch.ops.aten.roll.default(x, shifts, dims) @@ -38,6 +39,36 @@ def forward(self, x): inputs = [torch.randn(shape)] self.run_test(Roll(), inputs) + @parameterized.expand( + [ + # dim is empty + ((2,), (3,), (4,), [1], []), + ((2, 3), (3, 4), (4, 5), [1], []), + ((2, 3), (3, 4), (4, 5), [2], []), + ((2, 3), (3, 4), (4, 5), [-15], []), + ((2, 3, 3), (3, 4, 3), (4, 5, 4), [1], []), + # dim is not empty + ((2,), (3,), (4,), [1], [0]), + ((2, 3), (3, 4), (4, 5), [1], [1]), + ((2, 3), (3, 4), (4, 5), [2, 0], [0, 1]), + ((2, 3, 4), (3, 4, 5), (4, 5, 6), [-15, -2, 1], [0, 0, 1]), + ((2, 3, 3, 5), (3, 4, 3, 5), (4, 5, 4, 6), [11, -23], [0, 1]), + ] + ) + def test_roll_dynamic(self, min_shape, opt_shape, max_shape, shifts, dims): + class Roll(nn.Module): + def forward(self, x): + return torch.ops.aten.roll.default(x, shifts, dims) + + inputs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ) + ] + self.run_test_with_dynamic_shape(Roll(), inputs) + if __name__ == "__main__": run_tests() From 02063b27497caef736f7e7501b857587624f1c0f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 18 Jul 2024 20:23:35 -0700 Subject: [PATCH 2/8] test --- .../dynamo/conversion/impl/permutation.py | 243 +----------------- 1 file changed, 2 insertions(+), 241 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 9b7a8644bb..e2bc954781 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -36,144 +36,6 @@ def permute( return layer.get_output(0) -# def roll( -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR], -# name: str, -# input: TRTTensor, -# shifts: Union[int, Sequence[int]], -# dims: Union[int, Sequence[int]], -# ) -> TRTTensor: - -# if isinstance(shifts, int): -# shifts = [shifts] -# if isinstance(dims, int): -# dims = [dims] - -# if dims != []: -# # preprocess dims, in case that dims has multiple same dim -# # for example shifts:[1, 2, 1], dims: [1, 0, 1] -# # can be simplified to shifts: [2, 2], dims: [1, 0] -# shift_dict = {} -# for shift, dim in zip(shifts, dims): -# if dim in shift_dict: -# shift_dict[dim] += shift -# else: -# shift_dict[dim] = shift - -# is_dynamic_shape = has_dynamic_shape(input.shape) - -# # handle static 
shape for the input tensor: -# if not is_dynamic_shape: -# orignal_shape = input.shape -# if dims != []: -# # calculate start, stride when dims is not empty -# start = [0] * len(input.shape) -# stride = [1] * len(input.shape) -# for d, s in shift_dict.items(): -# start[d] = get_positive_dim(-s, input.shape[d]) -# else: -# # flatten input tensor -# input = impl.shuffle.reshape(ctx, target, source_ir, name+"_reshape", input, (1, -1)) -# # calculate start, stride when dims are empty -# print(f"lan added {orignal_shape=} {input.shape=}") -# start = [get_positive_dim(-shifts[0], input.shape[1])] * len(input.shape) -# stride = [1] * len(input.shape) -# print(f"lan added {start=} {stride=}") -# slice_layer = ctx.net.add_slice( -# input, -# start=start, -# shape=input.shape, -# stride=stride, -# ) -# slice_layer.mode = trt.SampleMode.WRAP -# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) -# output = slice_layer.get_output(0) -# if dims == []: -# # reshape back -# output = impl.shuffle.reshape(ctx, target, source_ir, name+"_reshape_back", output, orignal_shape) -# return output - -# # handle dynammic shape for the input tensor -# if dims != []: -# # calculate the start and stride -# rank = len(input.shape) -# print(f"lan added {shifts=}, {dims=}, {shift_dict=}") -# start = [] -# default_tensor = get_trt_tensor(ctx, 0, name+"_get_0") -# for i in range(rank): -# start.append(default_tensor) -# stride = [1] * rank -# for d, s in shift_dict.items(): -# if s < 0: -# start[d] = get_trt_tensor(ctx, -s, name+"_ge_{d}_{s}") -# else: -# dim_length = impl.shape.shape(ctx, target, source_ir, name+"_shape", input, d) -# start[d] = impl.elementwise.sub(ctx, target, source_ir, name+"_sub", dim_length, s) -# concat_layer = ctx.net.add_concatenation(start) -# concat_layer.axis = 0 -# set_layer_name(concat_layer, target, f"{name}_gather", source_ir) -# start = concat_layer.get_output(0) -# print(f"lan added {start=} {stride=}") -# # rolling the tensor by start and stride -# slice_layer = ctx.net.add_slice( -# input, -# start=[], -# shape=[], -# stride=stride, -# ) -# slice_layer.set_input(1, start) -# slice_layer.set_input(2, get_shape_with_dynamic_shape(ctx, target, source_ir, name+"_shape", input.shape, input)) -# slice_layer.mode = trt.SampleMode.WRAP -# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) -# return slice_layer.get_output(0) - -# else: -# # if dims is None, the tensor will be flattened before rolling and then restored to the original shape - -# # flatten the input tensor -# flattened_output = impl.shuffle.reshape( -# ctx, target, source_ir, f"{name}_reshape", input, (1, -1) -# ) - -# # calculate the start and stride -# if shifts[0] < 0: -# start_index = get_trt_tensor(ctx, -shifts[0], name+"_get") -# else: -# flattened_length = impl.shape.shape(ctx, target, source_ir, name+"_shape", flattened_output, 1) -# start_index = impl.elementwise.sub(ctx, target, source_ir, name+"_sub", flattened_length, shifts[0]) -# start, stride = [], [] -# for i in range(len(flattened_output.shape)): -# start.append(start_index) -# stride.append(1) - -# concat_layer = ctx.net.add_concatenation(start) -# concat_layer.axis = 0 -# set_layer_name(concat_layer, target, f"{name}_gather", source_ir) -# start = concat_layer.get_output(0) - -# # rolling the flattened tensor by start and stride -# slice_layer = ctx.net.add_slice( -# flattened_output, -# start=[], -# shape=[], -# stride=stride, -# ) -# slice_layer.set_input(1, start) -# slice_layer.set_input(2, 
get_shape_with_dynamic_shape(ctx, target, source_ir, name+"_output_shape", flattened_output.shape, flattened_output)) -# slice_layer.mode = trt.SampleMode.WRAP -# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) -# sliced_output = slice_layer.get_output(0) - -# # reshape back to the original shape -# shape_back = get_shape_with_dynamic_shape(ctx, target, source_ir, name+"_shape_back", input.shape, input) -# shape_layer = ctx.net.add_shuffle(sliced_output) -# shape_layer.set_input(1, shape_back) -# set_layer_name(shape_layer, target, name, source_ir) -# return shape_layer.get_output(0) - - def calc_start_by_static_shape( input: TRTTensor, shift_dict: Dict[int, int], @@ -191,7 +53,7 @@ def calc_start_by_dynamic_shape( name: str, input: TRTTensor, shift_dict: Dict[int, int], -) -> Sequence[TRTTensor]: +) -> TRTTensor: start = [] default_tensor = get_trt_tensor(ctx, 0, name + "_get_0") for i in range(len(input.shape)): @@ -204,108 +66,7 @@ def calc_start_by_dynamic_shape( concat_layer = ctx.net.add_concatenation(start) concat_layer.axis = 0 set_layer_name(concat_layer, target, f"{name}_gather", source_ir) - return [concat_layer.get_output(0)] - - -# def roll( -# ctx: ConversionContext, -# target: Target, -# source_ir: Optional[SourceIR], -# name: str, -# input: TRTTensor, -# shifts: Union[int, Sequence[int]], -# dims: Union[int, Sequence[int]], -# ) -> TRTTensor: -# if isinstance(shifts, int): -# shifts = [shifts] -# if isinstance(dims, int): -# dims = [dims] - -# shift_dict = {} -# if dims == []: -# shift_dict[1] = shifts[0] -# else: -# # preprocess dims, in case that dims has multiple same dim -# # for example shifts:[1, 2, 1], dims: [1, 0, 1] -# # can be simplified to shifts: [2, 2], dims: [1, 0] -# for shift, dim in zip(shifts, dims): -# if dim in shift_dict: -# shift_dict[dim] += shift -# else: -# shift_dict[dim] = shift - -# is_dynamic_shape = has_dynamic_shape(input.shape) - -# # handle static shape for the input tensor: -# if not is_dynamic_shape: -# orignal_shape = input.shape -# if dims == []: -# # flatten input tensor -# input = impl.shuffle.reshape( -# ctx, target, source_ir, name + "_reshape", input, (1, -1) -# ) -# start = calc_start_by_static_shape( -# input, shift_dict -# ) -# stride = [1] * len(input.shape) -# slice_layer = ctx.net.add_slice( -# input, -# start=start, -# shape=input.shape, -# stride=stride, -# ) -# slice_layer.mode = trt.SampleMode.WRAP -# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) -# output = slice_layer.get_output(0) -# if dims == []: -# # reshape back -# output = impl.shuffle.reshape( -# ctx, target, source_ir, name + "_reshape_back", output, orignal_shape -# ) -# else: -# # handle dynammic shape for the input tensor -# orignal_input = input -# if dims == []: -# # flatten the input tensor -# input = impl.shuffle.reshape( -# ctx, target, source_ir, f"{name}_reshape", input, (1, -1) -# ) -# start = calc_start_by_dynamic_shape( -# ctx, target, source_ir, name + "_calc", input, shift_dict -# ) -# stride = [1] * len(input.shape) -# slice_layer = ctx.net.add_slice( -# input, -# start=[], -# shape=[], -# stride=stride, -# ) -# slice_layer.set_input(1, start) -# slice_layer.set_input( -# 2, -# get_shape_with_dynamic_shape( -# ctx, target, source_ir, name + "_dynamic_shape", input.shape, input -# ), -# ) -# slice_layer.mode = trt.SampleMode.WRAP -# set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir) -# output = slice_layer.get_output(0) -# if dims == []: -# # reshape back to the original shape -# 
shape_back = get_shape_with_dynamic_shape( -# ctx, -# target, -# source_ir, -# name + "_shape_back", -# orignal_input.shape, -# orignal_input, -# ) -# shape_layer = ctx.net.add_shuffle(output) -# shape_layer.set_input(1, shape_back) -# set_layer_name(shape_layer, target, name+"_reshape_back", source_ir) -# output = shape_layer.get_output(0) - -# return output + return concat_layer.get_output(0) def roll( From 165b91da76faa90318224b99948d4facaee302b4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 19 Jul 2024 13:23:34 -0700 Subject: [PATCH 3/8] add more test cases --- .../dynamo/conversion/impl/permutation.py | 76 ++++++++++++------- tests/py/dynamo/conversion/test_roll_aten.py | 32 +++++++- 2 files changed, 81 insertions(+), 27 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index e2bc954781..8f86081fd0 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Sequence, Union +from typing import Optional, Sequence, Union import tensorrt as trt from torch.fx.node import Target @@ -38,8 +38,21 @@ def permute( def calc_start_by_static_shape( input: TRTTensor, - shift_dict: Dict[int, int], + shifts: Sequence[int], + dims: Sequence[int], ) -> Sequence[int]: + shift_dict = {} + if dims == []: + shift_dict[1] = shifts[0] + else: + # preprocess dims, in case that dims has multiple same dim + # for example shifts:[1, 2, 1], dims: [1, 0, 1] + # can be simplified to shifts: [2, 2], dims: [1, 0] + for shift, dim in zip(shifts, dims): + if dim in shift_dict: + shift_dict[dim] += shift + else: + shift_dict[dim] = shift start = [0] * len(input.shape) for d, s in shift_dict.items(): start[d] = get_positive_dim(-s, input.shape[d]) @@ -52,17 +65,34 @@ def calc_start_by_dynamic_shape( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - shift_dict: Dict[int, int], + shifts: Sequence[Union[int, TRTTensor]], + dims: Sequence[int], ) -> TRTTensor: - start = [] + start = [0] * len(input.shape) default_tensor = get_trt_tensor(ctx, 0, name + "_get_0") - for i in range(len(input.shape)): - start.append(default_tensor) - for d, s in shift_dict.items(): - dim_length = impl.shape.shape(ctx, target, source_ir, name + "_shape", input, d) - start[d] = impl.elementwise.sub( - ctx, target, source_ir, name + "_sub", dim_length, s + + if dims == []: + dim_length = impl.shape.shape(ctx, target, source_ir, name + "_shape", input, 1) + start[1] = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", dim_length, shifts[0] ) + else: + for d, s in zip(dims, shifts): + if isinstance(start[d], TRTTensor): + start[d] = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", start[d], s + ) + else: + dim_length = impl.shape.shape( + ctx, target, source_ir, name + "_shape", input, d + ) + start[d] = impl.elementwise.sub( + ctx, target, source_ir, name + "_sub", dim_length, s + ) + + for idx in range(len(start)): + if start[idx] == 0: + start[idx] = default_tensor concat_layer = ctx.net.add_concatenation(start) concat_layer.axis = 0 set_layer_name(concat_layer, target, f"{name}_gather", source_ir) @@ -75,9 +105,10 @@ def roll( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - shifts: Union[int, Sequence[int]], + shifts: Union[int, Sequence[Union[int, TRTTensor]]], dims: Union[int, Sequence[int]], ) -> TRTTensor: + print(f"lan added {shifts=} {dims=}") if isinstance(shifts, int): 
shifts = [shifts] if isinstance(dims, int): @@ -89,19 +120,6 @@ def roll( else: is_dynamic_shift = True - shift_dict = {} - if dims == []: - shift_dict[1] = shifts[0] - else: - # preprocess dims, in case that dims has multiple same dim - # for example shifts:[1, 2, 1], dims: [1, 0, 1] - # can be simplified to shifts: [2, 2], dims: [1, 0] - for shift, dim in zip(shifts, dims): - if dim in shift_dict: - shift_dict[dim] += shift - else: - shift_dict[dim] = shift - # handle static shape for the input tensor and shifts: if not is_dynamic_shape and not is_dynamic_shift: orignal_shape = input.shape @@ -110,7 +128,7 @@ def roll( input = impl.shuffle.reshape( ctx, target, source_ir, name + "_reshape", input, (1, -1) ) - start = calc_start_by_static_shape(input, shift_dict) + start = calc_start_by_static_shape(input, shifts, dims) stride = [1] * len(input.shape) slice_layer = ctx.net.add_slice( input, @@ -135,7 +153,13 @@ def roll( ctx, target, source_ir, f"{name}_reshape", input, (1, -1) ) start = calc_start_by_dynamic_shape( - ctx, target, source_ir, name + "_calc", input, shift_dict + ctx, + target, + source_ir, + name + "_calc", + input, + shifts, + dims, ) stride = [1] * len(input.shape) slice_layer = ctx.net.add_slice( diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py index 42567349c4..f79f6b4d1e 100644 --- a/tests/py/dynamo/conversion/test_roll_aten.py +++ b/tests/py/dynamo/conversion/test_roll_aten.py @@ -55,7 +55,9 @@ def forward(self, x): ((2, 3, 3, 5), (3, 4, 3, 5), (4, 5, 4, 6), [11, -23], [0, 1]), ] ) - def test_roll_dynamic(self, min_shape, opt_shape, max_shape, shifts, dims): + def test_roll_dynamic_input_static_shifts( + self, min_shape, opt_shape, max_shape, shifts, dims + ): class Roll(nn.Module): def forward(self, x): return torch.ops.aten.roll.default(x, shifts, dims) @@ -69,6 +71,34 @@ def forward(self, x): ] self.run_test_with_dynamic_shape(Roll(), inputs) + @parameterized.expand( + [ + ((2, 3), (3, 4), (4, 5)), + ((2, 3, 4), (3, 4, 5), (3, 5, 5)), + ] + ) + def test_roll_dynamic_input_dynamic_shifts(self, min_shape, opt_shape, max_shape): + class Roll(nn.Module): + def forward(self, x): + dims = [0, 1] + shifts = [x.shape[d] // 2 for d in dims] + return torch.ops.aten.roll.default(x, shifts, dims) + + inputs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + ) + ] + # TODO: Confirm with Dheeraj: + # this test is to test the torch.roll(input, shifts, dims) operator + # when only input is dynamic shape, both fx symbolic and dynamo tracer works as expected + # when both input and shifts are dynamic shape, I have to set use_dynamo_tracer=True + # otherwise this particular test will fail as follows: + # py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:480: UnsupportedOperatorException + self.run_test_with_dynamic_shape(Roll(), inputs, use_dynamo_tracer=True) + if __name__ == "__main__": run_tests() From 998f3819209dab130aa153ab2bc706f2f8247857 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Fri, 19 Jul 2024 13:25:31 -0700 Subject: [PATCH 4/8] test --- py/torch_tensorrt/dynamo/conversion/impl/permutation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index 8f86081fd0..ac46e33d6b 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -108,7 +108,6 @@ def roll( shifts: Union[int, 
Sequence[Union[int, TRTTensor]]], dims: Union[int, Sequence[int]], ) -> TRTTensor: - print(f"lan added {shifts=} {dims=}") if isinstance(shifts, int): shifts = [shifts] if isinstance(dims, int): From cca33ffee9a74f9a9127d2de0f2fd348432be9a6 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 24 Jul 2024 15:57:19 -0700 Subject: [PATCH 5/8] resolve comments --- py/torch_tensorrt/dynamo/conversion/impl/permutation.py | 8 ++++---- tests/py/dynamo/conversion/test_roll_aten.py | 6 ------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index ac46e33d6b..cc85748eee 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -113,14 +113,14 @@ def roll( if isinstance(dims, int): dims = [dims] - is_dynamic_shape = has_dynamic_shape(input.shape) + is_input_dynamic_shape = has_dynamic_shape(input.shape) if all(isinstance(shift, TRTTensor) for shift in shifts): - is_dynamic_shift = False + is_shifts_dynamic_shape = False else: - is_dynamic_shift = True + is_shifts_dynamic_shape = True # handle static shape for the input tensor and shifts: - if not is_dynamic_shape and not is_dynamic_shift: + if not is_input_dynamic_shape and not is_shifts_dynamic_shape: orignal_shape = input.shape if dims == []: # flatten input tensor diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py index f79f6b4d1e..1017e22c91 100644 --- a/tests/py/dynamo/conversion/test_roll_aten.py +++ b/tests/py/dynamo/conversion/test_roll_aten.py @@ -91,12 +91,6 @@ def forward(self, x): max_shape=max_shape, ) ] - # TODO: Confirm with Dheeraj: - # this test is to test the torch.roll(input, shifts, dims) operator - # when only input is dynamic shape, both fx symbolic and dynamo tracer works as expected - # when both input and shifts are dynamic shape, I have to set use_dynamo_tracer=True - # otherwise this particular test will fail as follows: - # py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:480: UnsupportedOperatorException self.run_test_with_dynamic_shape(Roll(), inputs, use_dynamo_tracer=True) From 3876972545f08421de25ee89cdd67a974fb793e1 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 24 Jul 2024 16:03:00 -0700 Subject: [PATCH 6/8] add comments --- py/torch_tensorrt/dynamo/conversion/impl/permutation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index cc85748eee..fcd42234ca 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -35,7 +35,9 @@ def permute( set_layer_name(layer, target, name, source_ir) return layer.get_output(0) - +# for the Tensorrt Slice layer: +# we need calculate the start offset that the slice layer uses to create the output slice. +# in this static shape scenario, the start returned is the sequence of int(constant) def calc_start_by_static_shape( input: TRTTensor, shifts: Sequence[int], @@ -58,7 +60,9 @@ def calc_start_by_static_shape( start[d] = get_positive_dim(-s, input.shape[d]) return start - +# for the Tensorrt Slice layer: +# we need calculate the start offset that the slice layer uses to create the output slice. 
+# in this dynamic shape scenario, the start returned is the tensor def calc_start_by_dynamic_shape( ctx: ConversionContext, target: Target, From f12a83b01cf7d053b9b5ca85d0803c2863bc1824 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 24 Jul 2024 17:02:25 -0700 Subject: [PATCH 7/8] test --- py/torch_tensorrt/dynamo/conversion/impl/permutation.py | 6 +++--- tests/py/dynamo/conversion/test_roll_aten.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index fcd42234ca..d444a62410 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -118,10 +118,10 @@ def roll( dims = [dims] is_input_dynamic_shape = has_dynamic_shape(input.shape) - if all(isinstance(shift, TRTTensor) for shift in shifts): - is_shifts_dynamic_shape = False - else: + if any(isinstance(shift, TRTTensor) for shift in shifts): is_shifts_dynamic_shape = True + else: + is_shifts_dynamic_shape = False # handle static shape for the input tensor and shifts: if not is_input_dynamic_shape and not is_shifts_dynamic_shape: diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py index 1017e22c91..1508844e91 100644 --- a/tests/py/dynamo/conversion/test_roll_aten.py +++ b/tests/py/dynamo/conversion/test_roll_aten.py @@ -73,6 +73,7 @@ def forward(self, x): @parameterized.expand( [ + ((2, 3), (3, 3), (4, 3)), ((2, 3), (3, 4), (4, 5)), ((2, 3, 4), (3, 4, 5), (3, 5, 5)), ] From 858ed14869baac8b9e8343720fe7bd2517939e5f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 24 Jul 2024 17:12:46 -0700 Subject: [PATCH 8/8] test --- py/torch_tensorrt/dynamo/conversion/impl/permutation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index d444a62410..1537d0fdbe 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -35,7 +35,8 @@ def permute( set_layer_name(layer, target, name, source_ir) return layer.get_output(0) -# for the Tensorrt Slice layer: + +# for the Tensorrt Slice layer: # we need calculate the start offset that the slice layer uses to create the output slice. # in this static shape scenario, the start returned is the sequence of int(constant) def calc_start_by_static_shape( @@ -60,7 +61,8 @@ def calc_start_by_static_shape( start[d] = get_positive_dim(-s, input.shape[d]) return start -# for the Tensorrt Slice layer: + +# for the Tensorrt Slice layer: # we need calculate the start offset that the slice layer uses to create the output slice. # in this dynamic shape scenario, the start returned is the tensor def calc_start_by_dynamic_shape(
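
# --- Illustration (not part of the patch series above) ---
# The converter added in py/torch_tensorrt/dynamo/conversion/impl/permutation.py
# lowers torch.roll onto a TensorRT slice layer in WRAP mode whose per-dimension
# start offset is (dim_len - shift) % dim_len, flattening first when dims == [].
# The minimal sketch below emulates that decomposition in plain PyTorch (no
# TensorRT) and checks it against torch.roll. The helper names simplify_shifts
# and wrap_slice_roll are illustrative only and do not exist in torch_tensorrt.
import torch


def simplify_shifts(shifts, dims):
    """Fold repeated dims, e.g. shifts=[1, 2, 1], dims=[1, 0, 1] -> {1: 2, 0: 2},
    mirroring the shift_dict preprocessing in calc_start_by_static_shape."""
    shift_dict = {}
    for shift, dim in zip(shifts, dims):
        shift_dict[dim] = shift_dict.get(dim, 0) + shift
    return shift_dict


def wrap_slice_roll(x, shifts, dims):
    """Emulate a wrap-mode slice: read n elements starting at (n - shift) % n."""
    if not dims:
        # dims == [] case: flatten, roll the flat tensor along dim 1, reshape back.
        flat = x.reshape(1, -1)
        return wrap_slice_roll(flat, shifts, [1]).reshape(x.shape)
    out = x
    for dim, shift in simplify_shifts(shifts, dims).items():
        n = out.shape[dim]
        start = (-shift) % n                 # same value as get_positive_dim(-shift, n)
        idx = (start + torch.arange(n)) % n  # wrap-around read pattern of the slice
        out = out.index_select(dim, idx)
    return out


if __name__ == "__main__":
    x = torch.arange(24).reshape(2, 3, 4)
    for shifts, dims in [([1], []), ([2, -1], [0, 2]), ([1, 2, 1], [1, 0, 1])]:
        ref = torch.roll(x, shifts, dims) if dims else torch.roll(x, shifts[0])
        assert torch.equal(wrap_slice_roll(x, shifts, dims), ref)
    print("wrap-mode slice decomposition matches torch.roll")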