Add dynamic support to roll/scaler_tensor #3023

Merged: 8 commits, Jul 25, 2024
7 changes: 5 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2726,6 +2726,7 @@ def attention_validator(node: Node) -> bool:
@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
supports_dynamic_shapes=True,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
@@ -3333,7 +3334,9 @@ def aten_ops_diagonal(
)


@dynamo_tensorrt_converter(
    torch.ops.aten.scalar_tensor.default, supports_dynamic_shapes=True
)
def aten_ops_scalar_tensor(
ctx: ConversionContext,
target: Target,
@@ -3346,7 +3349,7 @@ def aten_ops_scalar_tensor(
)


@dynamo_tensorrt_converter(torch.ops.aten.roll.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
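
With `supports_dynamic_shapes=True`, these converters become eligible when the graph is compiled with dynamic input dimensions. A minimal sketch of exercising the dynamic `aten.roll` path end to end (the module and shape ranges below are illustrative, not part of this PR):

import torch
import torch_tensorrt


class RollModule(torch.nn.Module):
    def forward(self, x):
        return torch.roll(x, shifts=(1, -2), dims=(0, 1))


# The first dimension is dynamic; one engine covers the whole min/opt/max range.
inputs = [
    torch_tensorrt.Input(
        min_shape=(2, 3), opt_shape=(4, 3), max_shape=(8, 3), dtype=torch.float32
    )
]
trt_mod = torch_tensorrt.compile(RollModule().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(6, 3).cuda()))
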
181 changes: 145 additions & 36 deletions py/torch_tensorrt/dynamo/conversion/impl/permutation.py
@@ -6,10 +6,12 @@
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
flatten_dims,
get_positive_dim,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor


@@ -34,59 +36,166 @@ def permute(
return layer.get_output(0)


# For the TensorRT slice layer:
# we need to calculate the start offset that the slice layer uses to create the output slice.
# In this static shape scenario, the start returned is a sequence of ints (constants).
def calc_start_by_static_shape(
input: TRTTensor,
shifts: Sequence[int],
dims: Sequence[int],
) -> Sequence[int]:
shift_dict = {}
if dims == []:
shift_dict[1] = shifts[0]
else:
# preprocess dims, in case that dims has multiple same dim
# for example shifts:[1, 2, 1], dims: [1, 0, 1]
# can be simplified to shifts: [2, 2], dims: [1, 0]
for shift, dim in zip(shifts, dims):
if dim in shift_dict:
shift_dict[dim] += shift
else:
shift_dict[dim] = shift
start = [0] * len(input.shape)
for d, s in shift_dict.items():
start[d] = get_positive_dim(-s, input.shape[d])
return start
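
As a quick sanity check (illustrative only, not part of this diff): a WRAP-mode slice that starts (-shift) % size elements into a dimension reads exactly the elements torch.roll would produce, which is the offset computed above.

import torch

# Editor's sketch: verify the wrap-around read pattern against torch.roll.
x = torch.arange(12).reshape(3, 4)
shift, dim = 2, 1
start = (-shift) % x.shape[dim]  # the offset calc_start_by_static_shape computes for this dim
idx = (start + torch.arange(x.shape[dim])) % x.shape[dim]
assert torch.equal(x.index_select(dim, idx), torch.roll(x, shift, dim))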


# For the TensorRT slice layer:
# we need to calculate the start offset that the slice layer uses to create the output slice.
# In this dynamic shape scenario, the start returned is a tensor.
def calc_start_by_dynamic_shape(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
shifts: Sequence[Union[int, TRTTensor]],
dims: Sequence[int],
) -> TRTTensor:
start = [0] * len(input.shape)
default_tensor = get_trt_tensor(ctx, 0, name + "_get_0")

if dims == []:
dim_length = impl.shape.shape(ctx, target, source_ir, name + "_shape", input, 1)
start[1] = impl.elementwise.sub(
ctx, target, source_ir, name + "_sub", dim_length, shifts[0]
)
else:
for d, s in zip(dims, shifts):
if isinstance(start[d], TRTTensor):
start[d] = impl.elementwise.sub(
ctx, target, source_ir, name + "_sub", start[d], s
)
else:
dim_length = impl.shape.shape(
ctx, target, source_ir, name + "_shape", input, d
)
start[d] = impl.elementwise.sub(
ctx, target, source_ir, name + "_sub", dim_length, s
)

for idx in range(len(start)):
if start[idx] == 0:
start[idx] = default_tensor
concat_layer = ctx.net.add_concatenation(start)
concat_layer.axis = 0
set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
return concat_layer.get_output(0)
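
In the dynamic case the per-dimension offsets are runtime tensors (dim_length - shift) that are concatenated into a single 1-D start tensor and wired into the slice layer via set_input. A plain-PyTorch sketch of that bookkeeping (illustrative only; the helper name is made up, and repeated dims are not accumulated here the way calc_start_by_dynamic_shape does):

import torch


def start_offsets(x: torch.Tensor, shifts, dims):
    # torch.tensor(x.shape[d]) stands in for the ITensor from impl.shape.shape,
    # and torch.stack stands in for the concatenation layer.
    start = [torch.tensor(0)] * x.dim()
    for s, d in zip(shifts, dims):
        start[d] = torch.tensor(x.shape[d]) - s  # wrap-mode slice begins size - shift elements in
    return torch.stack(start)


print(start_offsets(torch.empty(5, 4), shifts=[1, 2], dims=[0, 1]))  # tensor([4, 2])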


def roll(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
    shifts: Union[int, Sequence[Union[int, TRTTensor]]],
dims: Union[int, Sequence[int]],
) -> TRTTensor:
if isinstance(shifts, int):
shifts = [shifts]
if isinstance(dims, int):
dims = [dims]

is_input_dynamic_shape = has_dynamic_shape(input.shape)
if any(isinstance(shift, TRTTensor) for shift in shifts):
is_shifts_dynamic_shape = True
else:
is_shifts_dynamic_shape = False

    # handle static shape for the input tensor and shifts:
    if not is_input_dynamic_shape and not is_shifts_dynamic_shape:
        original_shape = input.shape
        if dims == []:
            # flatten the input tensor
            input = impl.shuffle.reshape(
                ctx, target, source_ir, name + "_reshape", input, (1, -1)
            )
        start = calc_start_by_static_shape(input, shifts, dims)
        stride = [1] * len(input.shape)
        slice_layer = ctx.net.add_slice(
            input,
            start=start,
            shape=input.shape,
            stride=stride,
        )
        slice_layer.mode = trt.SampleMode.WRAP
        set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir)
        output = slice_layer.get_output(0)
        if dims == []:
            # reshape back to the original shape
            output = impl.shuffle.reshape(
                ctx, target, source_ir, name + "_reshape_back", output, original_shape
            )
    else:
        # handle dynamic shape for the input tensor and shifts
        original_input = input
        if dims == []:
            # flatten the input tensor
            input = impl.shuffle.reshape(
                ctx, target, source_ir, f"{name}_reshape", input, (1, -1)
            )
        start = calc_start_by_dynamic_shape(
            ctx,
            target,
            source_ir,
            name + "_calc",
            input,
            shifts,
            dims,
        )
        stride = [1] * len(input.shape)
        slice_layer = ctx.net.add_slice(
            input,
            start=[],
            shape=[],
            stride=stride,
        )
        slice_layer.set_input(1, start)
        slice_layer.set_input(
            2,
            get_shape_with_dynamic_shape(
                ctx, target, source_ir, name + "_dynamic_shape", input.shape, input
            ),
        )
slice_layer.mode = trt.SampleMode.WRAP
set_layer_name(slice_layer, target, f"{name}_slice_wrap", source_ir)
output = slice_layer.get_output(0)
if dims == []:
# reshape back to the original shape
shape_back = get_shape_with_dynamic_shape(
ctx,
target,
source_ir,
name + "_shape_back",
                original_input.shape,
                original_input,
)
shape_layer = ctx.net.add_shuffle(output)
shape_layer.set_input(1, shape_back)
set_layer_name(shape_layer, target, name + "_reshape_back", source_ir)
output = shape_layer.get_output(0)

return output
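
Both branches special-case dims == [], matching torch.roll without dims: flatten, roll along the flattened axis, then reshape back. A small illustrative check (not part of this diff):

import torch

x = torch.arange(6).reshape(2, 3)
shift = 2
flat = x.reshape(1, -1)                   # flatten, keeping a leading unit dim as the converter does
rolled = torch.roll(flat, shift, dims=1)  # roll along the flattened axis
assert torch.equal(rolled.reshape(x.shape), torch.roll(x, shift))
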
58 changes: 57 additions & 1 deletion tests/py/dynamo/conversion/test_roll_aten.py
@@ -28,16 +28,72 @@ class TestRollConverter(DispatchTestCase):
],
[],
),
((2, 3), [1], [1]),
]
)
    def test_roll_static(self, shape, shifts, dims):
class Roll(nn.Module):
def forward(self, x):
return torch.ops.aten.roll.default(x, shifts, dims)

inputs = [torch.randn(shape)]
self.run_test(Roll(), inputs)

@parameterized.expand(
[
# dim is empty
((2,), (3,), (4,), [1], []),
((2, 3), (3, 4), (4, 5), [1], []),
((2, 3), (3, 4), (4, 5), [2], []),
((2, 3), (3, 4), (4, 5), [-15], []),
((2, 3, 3), (3, 4, 3), (4, 5, 4), [1], []),
# dim is not empty
((2,), (3,), (4,), [1], [0]),
((2, 3), (3, 4), (4, 5), [1], [1]),
((2, 3), (3, 4), (4, 5), [2, 0], [0, 1]),
((2, 3, 4), (3, 4, 5), (4, 5, 6), [-15, -2, 1], [0, 0, 1]),
((2, 3, 3, 5), (3, 4, 3, 5), (4, 5, 4, 6), [11, -23], [0, 1]),
]
)
def test_roll_dynamic_input_static_shifts(
self, min_shape, opt_shape, max_shape, shifts, dims
):
class Roll(nn.Module):
def forward(self, x):
return torch.ops.aten.roll.default(x, shifts, dims)

inputs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
)
]
self.run_test_with_dynamic_shape(Roll(), inputs)

@parameterized.expand(
[
((2, 3), (3, 3), (4, 3)),
((2, 3), (3, 4), (4, 5)),
((2, 3, 4), (3, 4, 5), (3, 5, 5)),
]
)
def test_roll_dynamic_input_dynamic_shifts(self, min_shape, opt_shape, max_shape):
class Roll(nn.Module):
def forward(self, x):
dims = [0, 1]
shifts = [x.shape[d] // 2 for d in dims]
return torch.ops.aten.roll.default(x, shifts, dims)

inputs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
)
]
self.run_test_with_dynamic_shape(Roll(), inputs, use_dynamo_tracer=True)


if __name__ == "__main__":
run_tests()
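
In the dynamic-shifts test above, the shifts are derived from x.shape, so under the dynamo tracer they become SymInt expressions rather than baked-in constants (hence use_dynamo_tracer=True). A sketch of what that trace looks like with plain torch.export (assumed usage, separate from the test harness):

import torch


class Roll(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.roll.default(x, [x.shape[0] // 2], [0])


batch = torch.export.Dim("batch", min=2, max=8)
ep = torch.export.export(Roll(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})
# The roll node's shift argument is a sym_size-derived expression, not a constant.
print(ep.graph)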