diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 3da1b09fba..514197b84d 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2726,6 +2726,7 @@ def attention_validator(node: Node) -> bool:
 @dynamo_tensorrt_converter(
     torch.nn.functional.scaled_dot_product_attention,
     capability_validator=attention_validator,
+    supports_dynamic_shapes=True,
 )
 def tensorrt_scaled_dot_product_attention(
     ctx: ConversionContext,
@@ -3333,7 +3334,9 @@ def aten_ops_diagonal(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.scalar_tensor.default, supports_dynamic_shapes=True
+)
 def aten_ops_scalar_tensor(
     ctx: ConversionContext,
     target: Target,
@@ -3346,7 +3349,7 @@ def aten_ops_scalar_tensor(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
+@dynamo_tensorrt_converter(torch.ops.aten.roll.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py
index 4fabebd176..1537d0fdbe 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py
@@ -6,10 +6,12 @@
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
-    flatten_dims,
     get_positive_dim,
+    get_trt_tensor,
+    has_dynamic_shape,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -34,59 +36,166 @@ def permute(
     return layer.get_output(0)
 
 
+# For the TensorRT Slice layer:
+# we need to calculate the start offset that the slice layer uses to create the output slice.
+# In this static shape scenario, the start returned is a sequence of ints (constants).
+def calc_start_by_static_shape(
+    input: TRTTensor,
+    shifts: Sequence[int],
+    dims: Sequence[int],
+) -> Sequence[int]:
+    shift_dict = {}
+    if dims == []:
+        shift_dict[1] = shifts[0]  # the flattened input rolls along dim 1
+    else:
+        # preprocess dims in case dims contains repeated entries,
+        # e.g. shifts: [1, 2, 1], dims: [1, 0, 1]
+        # can be simplified to shifts: [2, 2], dims: [1, 0]
+        for shift, dim in zip(shifts, dims):
+            if dim in shift_dict:
+                shift_dict[dim] += shift
+            else:
+                shift_dict[dim] = shift
+    start = [0] * len(input.shape)
+    for d, s in shift_dict.items():
+        start[d] = get_positive_dim(-s, input.shape[d])
+    return start
+
+
+# For the TensorRT Slice layer:
+# we need to calculate the start offset that the slice layer uses to create the output slice.
+# In this dynamic shape scenario, the start returned is a tensor.
+def calc_start_by_dynamic_shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    shifts: Sequence[Union[int, TRTTensor]],
+    dims: Sequence[int],
+) -> TRTTensor:
+    start = [0] * len(input.shape)
+    default_tensor = get_trt_tensor(ctx, 0, name + "_get_0")
+
+    if dims == []:
+        dim_length = impl.shape.shape(ctx, target, source_ir, name + "_shape", input, 1)
+        start[1] = impl.elementwise.sub(
+            ctx, target, source_ir, name + "_sub", dim_length, shifts[0]
+        )
+    else:
+        for d, s in zip(dims, shifts):
+            if isinstance(start[d], TRTTensor):
+                start[d] = impl.elementwise.sub(
+                    ctx, target, source_ir, name + "_sub", start[d], s
+                )
+            else:
+                dim_length = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape", input, d
+                )
+                start[d] = impl.elementwise.sub(
+                    ctx, target, source_ir, name + "_sub", dim_length, s
+                )
+
+    for idx in range(len(start)):
+        if start[idx] == 0:
+            start[idx] = default_tensor
+    concat_layer = ctx.net.add_concatenation(start)
+    concat_layer.axis = 0
+    set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
+    return concat_layer.get_output(0)
+
+
 def roll(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    shifts: Union[int, Sequence[int]],
+    shifts: Union[int, Sequence[Union[int, TRTTensor]]],
     dims: Union[int, Sequence[int]],
 ) -> TRTTensor:
-    shape = input.shape
     if isinstance(shifts, int):
         shifts = [shifts]
     if isinstance(dims, int):
         dims = [dims]
-    if dims != []:
-        rank = len(shape)
-        start = [0] * rank
-        stride = [1] * rank
-        for i in range(len(dims)):
-            d = dims[i]
-            s = shifts[i]
-            start[d] += get_positive_dim(
-                -s, shape[d]
-            )  # in case that dims has multiple same dim
-
-        layer = ctx.net.add_slice(
+    is_input_dynamic_shape = has_dynamic_shape(input.shape)
+    if any(isinstance(shift, TRTTensor) for shift in shifts):
+        is_shifts_dynamic_shape = True
+    else:
+        is_shifts_dynamic_shape = False
+
+    # handle the static shape case for the input tensor and shifts
+    if not is_input_dynamic_shape and not is_shifts_dynamic_shape:
+        original_shape = input.shape
+        if dims == []:
+            # flatten the input tensor
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape", input, (1, -1)
+            )
+        start = calc_start_by_static_shape(input, shifts, dims)
+        stride = [1] * len(input.shape)
+        slice_layer = ctx.net.add_slice(
             input,
             start=start,
-            shape=shape,
+            shape=input.shape,
             stride=stride,
         )
-        layer.mode = trt.SampleMode.WRAP
-        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
-        return layer.get_output(0)
-
+        slice_layer.mode = trt.SampleMode.WRAP
+        set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir)
+        output = slice_layer.get_output(0)
+        if dims == []:
+            # reshape back
+            output = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape_back", output, original_shape
+            )
     else:
-        flatten_shape = flatten_dims(input, 0, -1)
-        output = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
+        # handle the dynamic shape case for the input tensor and shifts
+        original_input = input
+        if dims == []:
+            # flatten the input tensor
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape", input, (1, -1)
+            )
+        start = calc_start_by_dynamic_shape(
+            ctx,
+            target,
+            source_ir,
+            name + "_calc",
+            input,
+            shifts,
+            dims,
         )
-        start = [get_positive_dim(-shifts[0], output.shape[0])]
-        stride = [1]
-        layer = ctx.net.add_slice(
-            output,
-            start=start,
-            shape=flatten_shape,
+        stride = [1] * len(input.shape)
+        slice_layer = ctx.net.add_slice(
+            input,
+            start=[],
+            shape=[],
             stride=stride,
         )
-        layer.mode = trt.SampleMode.WRAP
-        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
-        output = layer.get_output(0)
-        output = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_back", output, shape
+        slice_layer.set_input(1, start)
+        slice_layer.set_input(
+            2,
+            get_shape_with_dynamic_shape(
+                ctx, target, source_ir, name + "_dynamic_shape", input.shape, input
+            ),
         )
-        return output
+        slice_layer.mode = trt.SampleMode.WRAP
+        set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir)
+        output = slice_layer.get_output(0)
+        if dims == []:
+            # reshape back to the original shape
+            shape_back = get_shape_with_dynamic_shape(
+                ctx,
+                target,
+                source_ir,
+                name + "_shape_back",
+                original_input.shape,
+                original_input,
+            )
+            shape_layer = ctx.net.add_shuffle(output)
+            shape_layer.set_input(1, shape_back)
+            set_layer_name(shape_layer, target, name + "_reshape_back", source_ir)
+            output = shape_layer.get_output(0)
+
+    return output
diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py
index 80e9020855..1508844e91 100644
--- a/tests/py/dynamo/conversion/test_roll_aten.py
+++ b/tests/py/dynamo/conversion/test_roll_aten.py
@@ -28,9 +28,10 @@ class TestRollConverter(DispatchTestCase):
                 ],
                 [],
             ),
+            ((2, 3), [1], [1]),
         ]
     )
-    def test_roll(self, shape, shifts, dims):
+    def test_roll_static(self, shape, shifts, dims):
         class Roll(nn.Module):
             def forward(self, x):
                 return torch.ops.aten.roll.default(x, shifts, dims)
@@ -38,6 +39,61 @@ def forward(self, x):
         inputs = [torch.randn(shape)]
         self.run_test(Roll(), inputs)
 
+    @parameterized.expand(
+        [
+            # dims is empty
+            ((2,), (3,), (4,), [1], []),
+            ((2, 3), (3, 4), (4, 5), [1], []),
+            ((2, 3), (3, 4), (4, 5), [2], []),
+            ((2, 3), (3, 4), (4, 5), [-15], []),
+            ((2, 3, 3), (3, 4, 3), (4, 5, 4), [1], []),
+            # dims is not empty
+            ((2,), (3,), (4,), [1], [0]),
+            ((2, 3), (3, 4), (4, 5), [1], [1]),
+            ((2, 3), (3, 4), (4, 5), [2, 0], [0, 1]),
+            ((2, 3, 4), (3, 4, 5), (4, 5, 6), [-15, -2, 1], [0, 0, 1]),
+            ((2, 3, 3, 5), (3, 4, 3, 5), (4, 5, 4, 6), [11, -23], [0, 1]),
+        ]
+    )
+    def test_roll_dynamic_input_static_shifts(
+        self, min_shape, opt_shape, max_shape, shifts, dims
+    ):
+        class Roll(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.roll.default(x, shifts, dims)
+
+        inputs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            )
+        ]
+        self.run_test_with_dynamic_shape(Roll(), inputs)
+
+    @parameterized.expand(
+        [
+            ((2, 3), (3, 3), (4, 3)),
+            ((2, 3), (3, 4), (4, 5)),
+            ((2, 3, 4), (3, 4, 5), (3, 5, 5)),
+        ]
+    )
+    def test_roll_dynamic_input_dynamic_shifts(self, min_shape, opt_shape, max_shape):
+        class Roll(nn.Module):
+            def forward(self, x):
+                dims = [0, 1]
+                shifts = [x.shape[d] // 2 for d in dims]
+                return torch.ops.aten.roll.default(x, shifts, dims)
+
+        inputs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            )
+        ]
+        self.run_test_with_dynamic_shape(Roll(), inputs, use_dynamo_tracer=True)
+
 
 if __name__ == "__main__":
     run_tests()
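
For reviewers, a standalone sketch (not part of the patch) of the offset arithmetic the converter relies on: a WRAP-mode slice whose start along each rolled dimension is (-shift) % dim_size reproduces torch.roll, which is what calc_start_by_static_shape computes via get_positive_dim(-s, input.shape[d]). The emulated_wrap_slice helper below is hypothetical; it only emulates TensorRT's trt.SampleMode.WRAP in plain PyTorch to check the math, and is not part of the converter code.

import torch


def emulated_wrap_slice(x, start):
    # out[i0, i1, ...] = x[(start[0] + i0) % d0, (start[1] + i1) % d1, ...],
    # i.e. what a TensorRT Slice layer with trt.SampleMode.WRAP and stride 1 produces.
    out = x
    for d, s in enumerate(start):
        idx = (s + torch.arange(x.shape[d])) % x.shape[d]
        out = out.index_select(d, idx)
    return out


x = torch.arange(12).reshape(3, 4)
shifts, dims = [2, -1], [0, 1]
# start[d] = (-shift) % dim_size, mirroring get_positive_dim(-s, input.shape[d])
start = [0] * x.ndim
for s, d in zip(shifts, dims):
    start[d] = (start[d] - s) % x.shape[d]
assert torch.equal(emulated_wrap_slice(x, start), torch.roll(x, shifts, dims))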