From 81f40a32cb928a5b6680485e0a178a6be33c57e5 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 30 Oct 2023 12:36:18 -0700 Subject: [PATCH] Index ITensor test --- tests/py/dynamo/conversion/test_index_aten.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 393eb53c63..24a201aebf 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -2,11 +2,10 @@ import torch import torch.nn as nn +from harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from .harness import DispatchTestCase - class TestIndexConverter(DispatchTestCase): def test_index_zero_two_dim(self): @@ -27,6 +26,21 @@ def forward(self, x): input, ) + def test_index_zero_two_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test( + TestModule(), + [input, index0], + ) + def test_index_zero_index_three_dim(self): class TestModule(nn.Module): def __init__(self): @@ -44,6 +58,18 @@ def forward(self, x): input, ) + def test_index_zero_index_three_dim_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0, None] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2, 2) + index0 = torch.randint(0, 1, (1, 1)) + index0 = index0.to(torch.int32) + self.run_test(TestModule(), [input, index0]) + def test_index_zero_index_one_index_two_three_dim(self): class TestModule(nn.Module): def __init__(self):