diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 4b4a0807aa..485158648c 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -359,7 +359,7 @@ def aten_ops_cat( ) -> Union[TRTTensor, Sequence[TRTTensor]]: kwargs_new = { "tensors": args[0], - "dim": args[1], + "dim": args[1] if len(args) >= 2 else 0, } return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py index 55bd7b1e8b..bf0c2bea64 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py @@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #dim can not have dynamic input + ("neg", -2), ] ) def test_cat(self, _, dim): @@ -27,7 +27,7 @@ def forward(self, x, y, z): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #dim can not have dynamic input + ("neg", -2), ] ) def test_cat_dynamic_shape(self, _, dim): @@ -53,6 +53,41 @@ def forward(self, x, y): expected_ops={torch.ops.aten.cat.default}, ) + def test_cat_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z)) + + inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + def test_cat_dynamic_shape_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y)) + + input_specs = [ + InputTensorSpec( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + InputTensorSpec( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + if __name__ == "__main__": run_tests()