Skip to content

Commit

Permalink
tests: Fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
  • Loading branch information
narendasan committed Jul 26, 2022
1 parent b26d768 commit b2a5183
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 21 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ workflows:
requires:
- build-x86_64-pyt-release

- test-py-ts-x86_64:
- test-py-fx-x86_64:
name: test-py-fx-x86_64-pyt-release
channel: "release"
torch-build: << pipeline.parameters.torch-release-build >>
Expand Down Expand Up @@ -752,7 +752,7 @@ workflows:
requires:
- build-x86_64-pyt-release

- test-py-ts-x86_64:
- test-py-fx-x86_64:
name: test-py-fx-x86_64-pyt-release
channel: "release"
torch-build: << pipeline.parameters.torch-release-build >>
Expand Down
7 changes: 3 additions & 4 deletions cpp/include/torch_tensorrt/torch_tensorrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ struct TORCHTRT_API CompileSpec {
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);

/**
* @brief Construct a new Extra Info object
* @brief Construct a new Compile Spec object
* Convienence constructor to set fixed input size from c10::ArrayRef's (the
* output of tensor.sizes()) describing size of input tensors. Each entry in
* the vector represents a input and should be provided in call order.
Expand All @@ -583,7 +583,7 @@ struct TORCHTRT_API CompileSpec {
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

/**
* @brief Construct a new Extra Info object from input ranges.
* @brief Construct a new Compile Spec object from input ranges.
* Each entry in the vector represents a input and should be provided in call
* order.
*
Expand All @@ -594,8 +594,7 @@ struct TORCHTRT_API CompileSpec {
CompileSpec(std::vector<Input> inputs);

/**
* @brief Construct a new Extra Info object from IValue.
* The IValue store a complex Input
* @brief Construct a new Compile Spec object from IValue which represents the nesting of input tensors for a module.
*
* @param input_signature
*/
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ test_suite(
":test_serialization",
":test_module_fallback",
":test_example_tensors",
":test_collection"
":test_collections"
],
)

Expand All @@ -34,7 +34,7 @@ test_suite(
":test_serialization",
":test_module_fallback",
":test_example_tensors",
":test_collection"
":test_collections"
],
)

Expand Down Expand Up @@ -125,8 +125,8 @@ cc_test(
)

cc_test(
name = "test_collection",
srcs = ["test_collection.cpp"],
name = "test_collections",
srcs = ["test_collections.cpp"],
data = [
"//tests/modules:jit_models",
],
Expand Down
12 changes: 6 additions & 6 deletions tests/cpp/test_collection.cpp → tests/cpp/test_collections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

TEST(CppAPITests, TestCollectionStandardTensorInput) {

std::string path = "tests/modules/standard_tensor_input.jit.pt";
std::string path = "tests/modules/standard_tensor_input_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down Expand Up @@ -53,7 +53,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {

TEST(CppAPITests, TestCollectionTupleInput) {

std::string path = "tests/modules/tuple_input.jit.pt";
std::string path = "tests/modules/tuple_input_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

torch::jit::Module mod;
Expand Down Expand Up @@ -103,7 +103,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {

TEST(CppAPITests, TestCollectionListInput) {

std::string path = "tests/modules/list_input.jit.pt";
std::string path = "tests/modules/list_input_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down Expand Up @@ -169,7 +169,7 @@ TEST(CppAPITests, TestCollectionListInput) {

TEST(CppAPITests, TestCollectionTupleInputOutput) {

std::string path = "tests/modules/tuple_input_output.jit.pt";
std::string path = "tests/modules/tuple_input_output_scripted.jit.pt";

torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);

Expand Down Expand Up @@ -224,7 +224,7 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {

TEST(CppAPITests, TestCollectionListInputOutput) {

std::string path = "tests/modules/list_input_output.jit.pt";
std::string path = "tests/modules/list_input_output_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down Expand Up @@ -295,7 +295,7 @@ TEST(CppAPITests, TestCollectionListInputOutput) {

TEST(CppAPITests, TestCollectionComplexModel) {

std::string path = "tests/modules/complex_model.jit.pt";
std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
std::vector<at::Tensor> inputs;
inputs.push_back(in0);
Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/test_example_tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ TEST_P(CppAPITests, InputsFromTensors) {
trt_inputs_ivalues.push_back(in.clone());
}

auto spec = torch_tensorrt::ts::CompileSpec({trt_inputs_ivalues[0].toTensor()});

auto inputs = std::vector<torch_tensorrt::Input>{trt_inputs_ivalues[0].toTensor()};
auto spec = torch_tensorrt::ts::CompileSpec(inputs);

auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
Expand Down
8 changes: 4 additions & 4 deletions tests/modules/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@
"model": cm.ListInputTupleOutput(),
"path": "script"
},
"bert_base_uncased": {
"model": cm.BertModule(),
"path": "trace"
}
#"bert_base_uncased": {
# "model": cm.BertModule(),
# "path": "trace"
#}
}


Expand Down

0 comments on commit b2a5183

Please sign in to comment.