Skip to content

Commit

Permalink
fix: Add support for default dimension in aten.cat
Browse files Browse the repository at this point in the history
- Add default `dim=0` to the concatenation operator for call sites that do
not specify a concatenation dimension
- T5 encounters this error during compilation
- Add test cases covering the previously failing default-dimension usage
  • Loading branch information
gs-olive committed Apr 27, 2023
1 parent b3f433a commit 350e207
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def aten_ops_cat(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"tensors": args[0],
"dim": args[1],
"dim": args[1] if len(args) >= 2 else 0,
}
return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)

Expand Down
39 changes: 37 additions & 2 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase):
@parameterized.expand(
[
("pos", 1),
# ("neg", -2), #Dynamo tracer issue
("neg", -2),
]
)
def test_cat(self, _, dim):
Expand All @@ -27,7 +27,7 @@ def forward(self, x, y, z):
@parameterized.expand(
[
("pos", 1),
# ("neg", -2), #Dynamo tracer issue
("neg", -2),
]
)
def test_cat_dynamic_shape(self, _, dim):
Expand All @@ -53,6 +53,41 @@ def forward(self, x, y):
expected_ops={torch.ops.aten.cat.default},
)

def test_cat_no_dim(self):
    """Concatenation without an explicit `dim` argument.

    `torch.cat` defaults to `dim=0`, so the converter must handle the
    single-argument `aten.cat.default` call.
    """

    class Concatenate(nn.Module):
        def forward(self, a, b, c):
            # No dim passed -- exercises the converter's implicit dim=0 path.
            return torch.cat((a, b, c))

    # Inputs differ only along dim 0, the implicit concatenation axis.
    sample_inputs = [
        torch.randn(2, 1, 3),
        torch.randn(1, 1, 3),
        torch.randn(3, 1, 3),
    ]
    self.run_test(
        Concatenate(),
        sample_inputs,
        expected_ops={torch.ops.aten.cat.default},
    )

def test_cat_dynamic_shape_no_dim(self):
    """Dynamic-shape concatenation without an explicit `dim` argument.

    Both inputs have a dynamic batch dimension; `torch.cat` falls back
    to its default `dim=0`.
    """

    class Concatenate(nn.Module):
        def forward(self, a, b):
            return torch.cat((a, b))

    # Two inputs with identical dynamic-batch specs; build each spec
    # independently so no state is shared between them.
    input_specs = [
        InputTensorSpec(
            shape=(-1, 16, 3),
            dtype=torch.float32,
            shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
        )
        for _ in range(2)
    ]
    self.run_test_with_dynamic_shape(
        Concatenate(),
        input_specs,
        expected_ops={torch.ops.aten.cat.default},
    )


# Run this module's test cases directly via the PyTorch test runner.
if __name__ == "__main__":
    run_tests()

0 comments on commit 350e207

Please sign in to comment.