Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing grid lowering #2686

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ def aten_ops_fmod(

@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
aten.gelu,
aten.gelu_backward,
aten.glu_backward,
aten.grid_sampler_2d,
aten.hardshrink,
aten.hardshrink_backward,
aten.hardsigmoid,
Expand Down
171 changes: 106 additions & 65 deletions tests/py/dynamo/conversion/test_grid_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,112 +6,74 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

grid_sampler_aten_ops = {
"torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler,
"torch.ops.aten.grid_sampler_2d": torch.ops.aten.grid_sampler_2d,
"torch.ops.aten.grid_sampler.default": torch.ops.aten.grid_sampler.default,
"torch.ops.aten.grid_sampler_2d.default": torch.ops.aten.grid_sampler_2d.default,
}

grid_sampler_ops = [
(
"input_grid_interpolation_nearest_sample_fill",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_clamp",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_reflect",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_fill",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 1, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_clamp",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 1, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_reflect",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 1, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_fill",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 2, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_clamp",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 2, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_reflect",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_fill_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_clamp_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_reflect_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_fill_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_clamp_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_reflect_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_fill_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_clamp_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_reflect_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 2, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
Expand All @@ -126,19 +88,98 @@ class TestGridConverter(DispatchTestCase):
grid_sampler_op[1],
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this need to be grid_sampler_aten_ops[op_name](x, grid)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gs-olive yes it could be but then in that case it would be something like-

   def test_grid(self, _, op_name, input_shape, dim_shape, padding_mode, interpolation_mode, align_corners):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
                return grid_sampler_aten_ops[op_name](x, grid, padding_mode, interpolation_mode, align_corners)

        inputs = [torch.randn(input_shape, dtype=torch.float32)]
        grid_model = TestModule()
        self.run_test(grid_model, inputs)

It would be code design choice to either encapsulate it in the lambda function in the parameters, or declare it in the above way. In my opinion both should be good.

You could let me know if you think otherwise.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - this makes sense, thanks. In that case, it seems good to go - should just need a rebase to enable testing.


inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
self.run_test(grid_model, inputs)

@parameterized.expand(
[
(
grid_sampler_op[0],
grid_sampler_op[1] + "_2d",
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid_2d(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
self.run_test(grid_model, inputs)

@parameterized.expand(
[
(
grid_sampler_op[0],
grid_sampler_op[1] + ".default",
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid_default(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
self.run_test(grid_model, inputs)

@parameterized.expand(
[
(
grid_sampler_op[0],
grid_sampler_op[1] + "_2d.default",
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid(self, _, op, input_shape, dim_shape):
def test_grid_2d_default(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
Expand Down
Loading