Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dynamic shape for aten.linear #3011

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,8 +2502,8 @@ def aten_ops_convolution(
)


@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
@dynamo_tensorrt_converter(torch.ops.aten.linear)
@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.linear, supports_dynamic_shapes=True)
def aten_ops_linear(
ctx: ConversionContext,
target: Target,
Expand Down
117 changes: 86 additions & 31 deletions tests/py/dynamo/conversion/test_linear_aten.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,44 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestLinearConverter(DispatchTestCase):
@parameterized.expand(
[
("default", [1, 512], True, torch.ops.aten.linear.default),
("matrix", [5, 512], True, torch.ops.aten.linear.default),
("no_bias", [1, 512], False, torch.ops.aten.linear.default),
(
"default",
[1, 512],
True,
),
(
"matrix",
[5, 512],
True,
),
(
"no_bias",
[1, 512],
False,
),
(
"multi_dim_matrix",
[4, 5, 512],
True,
torch.ops.aten.linear.default,
),
(
"multi_dim_matrix",
[4, 5, 512],
False,
torch.ops.aten.linear.default,
),
]
)
def test_linear(self, test_name, shape, bias, op):
class TestModule(torch.nn.Module):
def test_linear(self, test_name, shape, bias):
class linear(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn((256, 512))
Expand All @@ -39,37 +51,80 @@ def forward(self, x):
return torch.ops.aten.linear.default(x, self.weight, self.bias)

inputs = [torch.randn(shape)]
self.run_test(TestModule(), inputs)
self.run_test(linear(), inputs)

# linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern
# like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below.

# Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now.

# def test_linear_with_dynamic_shape(self):
# class TestModule(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.linear = torch.nn.Linear(512, 256)
# # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now.
@parameterized.expand(
[
(
"2d_dim",
(1, 512),
(2, 512),
(3, 512),
torch.float32,
(256, 512),
None,
),
(
"3d_one_dynamic_dim",
(1, 1, 512),
(2, 2, 512),
(3, 3, 512),
torch.float32,
(256, 512),
(256,),
),
(
"3d_two_dynamic_dim_bias",
(1, 1, 512),
(2, 2, 512),
(3, 3, 512),
torch.float32,
(256, 512),
(256,),
),
(
"3d_two_dynamic_dim_no_bias",
(1, 1, 512),
(2, 2, 512),
(3, 3, 512),
torch.float32,
(256, 512),
None,
),
]
)
def test_linear_with_dynamic_shape(
self, _, min_shape, opt_shape, max_shape, type, weight_shape, bias_shape
):
class linear(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.rand(weight_shape)

# def forward(self, x):
# return self.linear(x)
if bias_shape:
self.bias = torch.randn(bias_shape)
else:
self.bias = None

# input_specs = [
# Input(
# shape=(-1, 3, 512),
# dtype=torch.float32,
# shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))],
# ),
# ]
# self.run_test_with_dynamic_shape(
# TestModule(),
# input_specs,
# expected_ops={torch.ops.aten.addmm.default},
# )
def forward(self, x):
return torch.ops.aten.linear.default(x, self.weight, self.bias)

## Testing with (-1, -1, 512) results into following error:
## AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim.
input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=type,
),
]
self.run_test_with_dynamic_shape(
linear(),
input_specs,
)


if __name__ == "__main__":
Expand Down
Loading