Skip to content

Commit

Permalink
feat(//cpp): Adding example tensors as a way to set input spec
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Aug 31, 2021
1 parent 01d525d commit 70a7bb3
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
16 changes: 13 additions & 3 deletions cpp/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ struct TRTORCH_API CompileSpec {
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input Range object dynamic input size from
* @brief Construct a new Input spec object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
Expand Down Expand Up @@ -462,7 +462,7 @@ struct TRTORCH_API CompileSpec {
TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input Range object dynamic input size from
* @brief Construct a new Input spec object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
Expand All @@ -479,7 +479,7 @@ struct TRTORCH_API CompileSpec {
TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input Range object dynamic input size from
* @brief Construct a new Input spec object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes
*
Expand All @@ -496,6 +496,16 @@ struct TRTORCH_API CompileSpec {
DataType dtype,
TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input spec object using a torch tensor as an example
* The tensor's shape, type and layout inform the spec's values
*
* Note: You cannot set dynamic shape through this method, you must use an alternative constructor
*
* @param tensor Reference tensor to set shape, type and layout
*/
Input(at::Tensor tensor);

bool get_explicit_set_dtype() {
return explicit_set_dtype;
}
Expand Down
18 changes: 18 additions & 0 deletions cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,24 @@ CompileSpec::Input::Input(
this->input_is_dynamic = true;
}

CompileSpec::Input::Input(at::Tensor tensor) {
this->opt_shape = tensor.sizes().vec();
this->min_shape = tensor.sizes().vec();
this->max_shape = tensor.sizes().vec();
this->shape = tensor.sizes().vec();
this->dtype = tensor.scalar_type();
this->explicit_set_dtype = true;
TRTORCH_ASSERT(tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous), "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
at::MemoryFormat frmt;
if (tensor.is_contiguous(at::MemoryFormat::Contiguous)) {
frmt = at::MemoryFormat::Contiguous;
} else {
frmt = at::MemoryFormat::ChannelsLast;
}
this->format = frmt;
this->input_is_dynamic = false;
}

/* ==========================================*/

core::ir::Input to_internal_input(CompileSpec::InputRange& i) {
Expand Down
17 changes: 15 additions & 2 deletions tests/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ test_suite(
":test_modules_as_engines",
":test_multiple_registered_engines",
":test_serialization",
":test_module_fallback"
":test_module_fallback",
":test_example_tensors"
],
)

Expand All @@ -28,7 +29,8 @@ test_suite(
":test_modules_as_engines",
":test_multiple_registered_engines",
":test_serialization",
":test_module_fallback"
":test_module_fallback",
":test_example_tensors"
],
)

Expand All @@ -43,6 +45,17 @@ cc_test(
],
)

cc_test(
name = "test_example_tensors",
srcs = ["test_example_tensors.cpp"],
data = [
"//tests/modules:jit_models",
],
deps = [
":cpp_api_test",
],
)

cc_test(
name = "test_serialization",
srcs = ["test_serialization.cpp"],
Expand Down
24 changes: 24 additions & 0 deletions tests/cpp/test_example_tensors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "cpp_api_test.h"

TEST_P(CppAPITests, InputsFromTensors) {
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randn(in_shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}

auto spec = trtorch::CompileSpec({trt_inputs_ivalues[0].toTensor()});

auto trt_mod = trtorch::CompileGraph(mod, spec);
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
}

INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));

0 comments on commit 70a7bb3

Please sign in to comment.