From 2237eebff241e3e2257e58ab39b9ab613c5b4ad6 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 27 Apr 2023 09:55:27 -0700 Subject: [PATCH] fix: Add support for default dimension in `aten.cat` - Add default `dim=0` to concatenation operator for use cases which do not have a specific concatenation dimension specified - T5 encounters this error during compilation - Add test cases to elicit error with default dimension --- .../fx/converters/aten_ops_converters.py | 2 +- .../test/converters/aten_op/test_cat_aten.py | 39 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index c86f2bd228..e8efc30ddf 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -358,7 +358,7 @@ def aten_ops_cat( ) -> Union[TRTTensor, Sequence[TRTTensor]]: kwargs_new = { "tensors": args[0], - "dim": args[1], + "dim": args[1] if len(args) >= 2 else 0, } return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py index 55bd7b1e8b..bf0c2bea64 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py @@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #dim can not have dynamic input + ("neg", -2), ] ) def test_cat(self, _, dim): @@ -27,7 +27,7 @@ def forward(self, x, y, z): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #dim can not have dynamic input + ("neg", -2), ] ) def test_cat_dynamic_shape(self, _, dim): @@ -53,6 +53,41 @@ def forward(self, x, y): expected_ops={torch.ops.aten.cat.default}, ) + def test_cat_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z)) + + inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + def test_cat_dynamic_shape_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y)) + + input_specs = [ + InputTensorSpec( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + InputTensorSpec( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + if __name__ == "__main__": run_tests()