Skip to content

Commit

Permalink
Merge pull request #616 from NVIDIA/example_tensors
Browse files Browse the repository at this point in the history
Example tensors
  • Loading branch information
narendasan authored Sep 13, 2021
2 parents e95aa99 + 122429f commit 0ec2eb3
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 6 deletions.
16 changes: 13 additions & 3 deletions cpp/include/trtorch/trtorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ struct TRTORCH_API CompileSpec {
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input Range object dynamic input size from
* @brief Construct a new Input spec object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
Expand Down Expand Up @@ -462,7 +462,7 @@ struct TRTORCH_API CompileSpec {
TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input Range object dynamic input size from
* @brief Construct a new Input spec object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
Expand All @@ -479,7 +479,7 @@ struct TRTORCH_API CompileSpec {
TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input Range object dynamic input size from
* @brief Construct a new Input spec object dynamic input size from
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
* supported sizes
*
Expand All @@ -496,6 +496,16 @@ struct TRTORCH_API CompileSpec {
DataType dtype,
TensorFormat format = TensorFormat::kContiguous);

/**
* @brief Construct a new Input spec object using a torch tensor as an example
* The tensor's shape, type and layout inform the spec's values
*
* Note: You cannot set dynamic shape through this method, you must use an alternative constructor
*
* @param tensor Reference tensor to set shape, type and layout
*/
Input(at::Tensor tensor);

bool get_explicit_set_dtype() {
return explicit_set_dtype;
}
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,26 @@ CompileSpec::Input::Input(
this->input_is_dynamic = true;
}

CompileSpec::Input::Input(at::Tensor tensor) {
this->opt_shape = tensor.sizes().vec();
this->min_shape = tensor.sizes().vec();
this->max_shape = tensor.sizes().vec();
this->shape = tensor.sizes().vec();
this->dtype = tensor.scalar_type();
this->explicit_set_dtype = true;
TRTORCH_ASSERT(
tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
"Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
at::MemoryFormat frmt;
if (tensor.is_contiguous(at::MemoryFormat::Contiguous)) {
frmt = at::MemoryFormat::Contiguous;
} else {
frmt = at::MemoryFormat::ChannelsLast;
}
this->format = frmt;
this->input_is_dynamic = false;
}

/* ==========================================*/

core::ir::Input to_internal_input(CompileSpec::InputRange& i) {
Expand Down
13 changes: 13 additions & 0 deletions py/trtorch/Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,16 @@ def _parse_format(format: Any) -> _types.TensorFormat:
else:
raise TypeError(
"Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")

@classmethod
def _from_tensor(cls, t: torch.Tensor):
if not any([
t.is_contiguous(memory_format=torch.contiguous_format),
t.is_contiguous(memory_format=torch.channels_last)
]):
raise ValueError(
"Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last"
)
frmt = torch.contiguous_format if t.is_contiguous(
memory_format=torch.contiguous_format) else torch.channels_last
return cls(shape=t.shape, dtype=t.dtype, format=frmt)
7 changes: 6 additions & 1 deletion py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])

if "inputs" in compile_spec:
info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
if not all([isinstance(i, torch.Tensor) or isinstance(i, trtorch.Input) for i in compile_spec["inputs"]]):
raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format(
[typeof(i) for i in compile_spec["inputs"]]))

inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
info.inputs = [i._to_internal() for i in inputs]

if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
raise KeyError(
Expand Down
17 changes: 15 additions & 2 deletions tests/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ test_suite(
":test_modules_as_engines",
":test_multiple_registered_engines",
":test_serialization",
":test_module_fallback"
":test_module_fallback",
":test_example_tensors"
],
)

Expand All @@ -28,7 +29,8 @@ test_suite(
":test_modules_as_engines",
":test_multiple_registered_engines",
":test_serialization",
":test_module_fallback"
":test_module_fallback",
":test_example_tensors"
],
)

Expand All @@ -43,6 +45,17 @@ cc_test(
],
)

cc_test(
name = "test_example_tensors",
srcs = ["test_example_tensors.cpp"],
data = [
"//tests/modules:jit_models",
],
deps = [
":cpp_api_test",
],
)

cc_test(
name = "test_serialization",
srcs = ["test_serialization.cpp"],
Expand Down
23 changes: 23 additions & 0 deletions tests/cpp/test_example_tensors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "cpp_api_test.h"

TEST_P(CppAPITests, InputsFromTensors) {
std::vector<torch::jit::IValue> jit_inputs_ivalues;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
for (auto in_shape : input_shapes) {
auto in = at::randn(in_shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
trt_inputs_ivalues.push_back(in.clone());
}

auto spec = trtorch::CompileSpec({trt_inputs_ivalues[0].toTensor()});

auto trt_mod = trtorch::CompileGraph(mod, spec);
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
std::vector<at::Tensor> trt_results;
trt_results.push_back(trt_results_ivalues.toTensor());
}

INSTANTIATE_TEST_SUITE_P(
CompiledModuleForwardIsCloseSuite,
CppAPITests,
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
21 changes: 21 additions & 0 deletions tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ def test_compile_script(self):
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)

def test_from_torch_tensor(self):
compile_spec = {
"inputs": [self.input],
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
},
"enabled_precisions": {torch.float}
}

trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)

def test_device(self):
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}

trt_mod = trtorch.compile(self.scripted_model, compile_spec)
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)


class TestCompileHalf(ModelTestCase):

Expand Down

0 comments on commit 0ec2eb3

Please sign in to comment.