Skip to content

Commit

Permalink
Removing grid lowering (#2686)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored and peri044 committed Apr 19, 2024
1 parent dee74c4 commit 77c4b96
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 66 deletions.
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ def aten_ops_fmod(

@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default)
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
aten.gelu,
aten.gelu_backward,
aten.glu_backward,
aten.grid_sampler_2d,
aten.hardshrink,
aten.hardshrink_backward,
aten.hardsigmoid,
Expand Down
171 changes: 106 additions & 65 deletions tests/py/dynamo/conversion/test_grid_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,112 +6,74 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

grid_sampler_aten_ops = {
"torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler,
"torch.ops.aten.grid_sampler_2d": torch.ops.aten.grid_sampler_2d,
"torch.ops.aten.grid_sampler.default": torch.ops.aten.grid_sampler.default,
"torch.ops.aten.grid_sampler_2d.default": torch.ops.aten.grid_sampler_2d.default,
}

grid_sampler_ops = [
(
"input_grid_interpolation_nearest_sample_fill",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_clamp",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_reflect",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_fill",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 1, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_clamp",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 1, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_reflect",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 1, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_fill",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 2, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_clamp",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 2, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_reflect",
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_fill_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_clamp_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_nearest_sample_reflect_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_fill_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_clamp_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_linear_sample_reflect_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_fill_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_clamp_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
(
"input_grid_interpolation_cubic_sample_reflect_2d",
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
"torch.ops.aten.grid_sampler",
(lambda x, grid, op: op(x, grid, 2, 2, True)),
[1, 1, 5, 5],
[1, 5, 2, 2],
),
Expand All @@ -126,19 +88,98 @@ class TestGridConverter(DispatchTestCase):
grid_sampler_op[1],
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
self.run_test(grid_model, inputs)

@parameterized.expand(
[
(
grid_sampler_op[0],
grid_sampler_op[1] + "_2d",
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid_2d(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
self.run_test(grid_model, inputs)

@parameterized.expand(
[
(
grid_sampler_op[0],
grid_sampler_op[1] + ".default",
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid_default(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
self.run_test(grid_model, inputs)

@parameterized.expand(
[
(
grid_sampler_op[0],
grid_sampler_op[1] + "_2d.default",
grid_sampler_op[2],
grid_sampler_op[3],
grid_sampler_op[4],
)
for grid_sampler_op in grid_sampler_ops
]
)
def test_grid(self, _, op, input_shape, dim_shape):
def test_grid_2d_default(self, _, op_name, op, input_shape, dim_shape):
class TestModule(nn.Module):
def __init__(self, grid_sampler_op):
super().__init__()
self.grid_sampler_op = grid_sampler_op

def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return self.grid_sampler_op(x, grid)
return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name])

inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule(op)
Expand Down

0 comments on commit 77c4b96

Please sign in to comment.