Skip to content

Commit

Permalink
tests: fix test model paths
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
  • Loading branch information
narendasan committed Jul 26, 2022
1 parent 7393fa8 commit b26d768
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/py/api/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch_tensorrt as torchtrt
import torch
import torchvision.models as models
import os

def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -22,7 +23,7 @@ class TestStandardTensorInput(unittest.TestCase):
def test_compile(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/standard_tensor_input.jit.pt").eval().to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt").eval().to("cuda")

compile_spec = {
"inputs": [torchtrt.Input(self.input.shape),
Expand All @@ -41,7 +42,7 @@ class TestTupleInput(unittest.TestCase):
def test_compile(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/tuple_input.jit.pt").eval().to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/tuple_input_scripted.jit.pt").eval().to("cuda")

compile_spec = {
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
Expand All @@ -61,7 +62,7 @@ class TestListInput(unittest.TestCase):
def test_compile(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/list_input.jit.pt").eval().to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/list_input_scripted.jit.pt").eval().to("cuda")


compile_spec = {
Expand All @@ -81,7 +82,7 @@ class TestTupleInputOutput(unittest.TestCase):
def test_compile(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/tuple_input_output.jit.pt").eval().to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt").eval().to("cuda")


compile_spec = {
Expand All @@ -103,7 +104,7 @@ class TestListInputOutput(unittest.TestCase):
def test_compile(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/list_input_output.jit.pt").eval().to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt").eval().to("cuda")


compile_spec = {
Expand All @@ -126,7 +127,7 @@ class TestListInputTupleOutput(unittest.TestCase):
def test_compile(self):

self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/list_input_tuple_output.jit.pt").eval().to("cuda")
self.model = torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt").eval().to("cuda")


compile_spec = {
Expand Down

0 comments on commit b26d768

Please sign in to comment.