diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 45949a1c8d..bb382da15b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -332,6 +332,8 @@ def aten_ops_fmod( @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index de791851db..d365d6d9c5 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -46,7 +46,6 @@ aten.gelu, aten.gelu_backward, aten.glu_backward, - aten.grid_sampler_2d, aten.hardshrink, aten.hardshrink_backward, aten.hardsigmoid, diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py index 32480110f3..e3b5783b19 100644 --- a/tests/py/dynamo/conversion/test_grid_aten.py +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -6,112 +6,74 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +grid_sampler_aten_ops = { + "torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler, + "torch.ops.aten.grid_sampler_2d": torch.ops.aten.grid_sampler_2d, + "torch.ops.aten.grid_sampler.default": torch.ops.aten.grid_sampler.default, + "torch.ops.aten.grid_sampler_2d.default": torch.ops.aten.grid_sampler_2d.default, +} + grid_sampler_ops = [ ( "input_grid_interpolation_nearest_sample_fill", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 0, 0, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_nearest_sample_clamp", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 0, 1, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_nearest_sample_reflect", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 0, 2, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_linear_sample_fill", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 1, 0, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_linear_sample_clamp", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 1, 1, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_linear_sample_reflect", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 1, 2, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_cubic_sample_fill", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 2, 0, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_cubic_sample_clamp", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 2, 1, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_cubic_sample_reflect", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_nearest_sample_fill_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_nearest_sample_clamp_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_nearest_sample_reflect_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_linear_sample_fill_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_linear_sample_clamp_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_linear_sample_reflect_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_cubic_sample_fill_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_cubic_sample_clamp_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_cubic_sample_reflect_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 2, 2, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), @@ -126,11 +88,90 @@ class TestGridConverter(DispatchTestCase): grid_sampler_op[1], grid_sampler_op[2], grid_sampler_op[3], + grid_sampler_op[4], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid(self, _, op_name, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1] + "_2d", + grid_sampler_op[2], + grid_sampler_op[3], + grid_sampler_op[4], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid_2d(self, _, op_name, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1] + ".default", + grid_sampler_op[2], + grid_sampler_op[3], + grid_sampler_op[4], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid_default(self, _, op_name, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1] + "_2d.default", + grid_sampler_op[2], + grid_sampler_op[3], + grid_sampler_op[4], ) for grid_sampler_op in grid_sampler_ops ] ) - def test_grid(self, _, op, input_shape, dim_shape): + def test_grid_2d_default(self, _, op_name, op, input_shape, dim_shape): class TestModule(nn.Module): def __init__(self, grid_sampler_op): super().__init__() @@ -138,7 +179,7 @@ def __init__(self, grid_sampler_op): def forward(self, x): grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) - return self.grid_sampler_op(x, grid) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) inputs = [torch.randn(input_shape, dtype=torch.float32)] grid_model = TestModule(op)