Skip to content

Commit

Permalink
test: [collection] update model path in test_collection.cpp
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <vcheungyi@163.com>
  • Loading branch information
inocsin committed Mar 31, 2022
1 parent 76e3886 commit eaf1254
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
12 changes: 6 additions & 6 deletions tests/cpp/test_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

TEST(CppAPITests, TestCollectionNormalInput) {

std::string path = "/root/Torch-TensorRT/normal_model.ts";
std::string path = "tests/modules/normal_model.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down Expand Up @@ -53,7 +53,7 @@ TEST(CppAPITests, TestCollectionNormalInput) {

TEST(CppAPITests, TestCollectionTupleInput) {

std::string path = "/root/Torch-TensorRT/tuple_input.ts";
std::string path = "tests/modules/tuple_input.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

torch::jit::Module mod;
Expand Down Expand Up @@ -103,7 +103,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {

TEST(CppAPITests, TestCollectionListInput) {

std::string path = "/root/Torch-TensorRT/list_input.ts";
std::string path = "tests/modules/list_input.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down Expand Up @@ -169,7 +169,7 @@ TEST(CppAPITests, TestCollectionListInput) {

TEST(CppAPITests, TestCollectionTupleInputOutput) {

std::string path = "/root/Torch-TensorRT/tuple_input_output.ts";
std::string path = "tests/modules/tuple_input_output.jit.pt";

torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

Expand Down Expand Up @@ -224,7 +224,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {

TEST(CppAPITests, TestCollectionListInputOutput) {

std::string path = "/root/Torch-TensorRT/list_input_output.ts";
std::string path = "tests/modules/list_input_output.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down Expand Up @@ -296,7 +296,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) {

TEST(CppAPITests, TestCollectionComplexModel) {

std::string path = "/root/Torch-TensorRT/complex_model.ts";
std::string path = "tests/modules/complex_model.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down
12 changes: 6 additions & 6 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,29 +256,29 @@ def forward(self, z: List[torch.Tensor]):
normal_model = Normal()
normal_model_ts = torch.jit.script(normal_model)
normal_model_ts.to("cuda").eval()
torch.jit.save(normal_model_ts, "normal_model.ts")
torch.jit.save(normal_model_ts, "normal_model.jit.pt")

tuple_input = TupleInput()
tuple_input_ts = torch.jit.script(tuple_input)
tuple_input_ts.to("cuda").eval()
torch.jit.save(tuple_input_ts, "tuple_input.ts")
torch.jit.save(tuple_input_ts, "tuple_input.jit.pt")

list_input = ListInput()
list_input_ts = torch.jit.script(list_input)
list_input_ts.to("cuda").eval()
torch.jit.save(list_input_ts, "list_input.ts")
torch.jit.save(list_input_ts, "list_input.jit.pt")

tuple_input = TupleInputOutput()
tuple_input_ts = torch.jit.script(tuple_input)
tuple_input_ts.to("cuda").eval()
torch.jit.save(tuple_input_ts, "tuple_input_output.ts")
torch.jit.save(tuple_input_ts, "tuple_input_output.jit.pt")

list_input = ListInputOutput()
list_input_ts = torch.jit.script(list_input)
list_input_ts.to("cuda").eval()
torch.jit.save(list_input_ts, "list_input_output.ts")
torch.jit.save(list_input_ts, "list_input_output.jit.pt")

complex_model = ComplexModel()
complex_model_ts = torch.jit.script(complex_model)
complex_model_ts.to("cuda").eval()
torch.jit.save(complex_model_ts, "complex_model.ts")
torch.jit.save(complex_model_ts, "complex_model.jit.pt")

0 comments on commit eaf1254

Please sign in to comment.