Skip to content

Commit

Permalink
Index ITensor test
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Oct 30, 2023
1 parent 563ca81 commit 81f40a3
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions tests/py/dynamo/conversion/test_index_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import torch
import torch.nn as nn
from harness import DispatchTestCase
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestIndexConverter(DispatchTestCase):
def test_index_zero_two_dim(self):
Expand All @@ -27,6 +26,21 @@ def forward(self, x):
input,
)

def test_index_zero_two_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = torch.randn(2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
self.run_test(
TestModule(),
[input, index0],
)

def test_index_zero_index_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
Expand All @@ -44,6 +58,18 @@ def forward(self, x):
input,
)

def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0, None]
out = torch.ops.aten.index.Tensor(x, indices)
return out

input = torch.randn(2, 2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
self.run_test(TestModule(), [input, index0])

def test_index_zero_index_one_index_two_three_dim(self):
class TestModule(nn.Module):
def __init__(self):
Expand Down

0 comments on commit 81f40a3

Please sign in to comment.