From 19ecc640664b7bca394edd2e7c305cf2dab83dc3 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Mon, 18 Oct 2021 08:27:51 -0700
Subject: [PATCH] refactor!(//cpp): Inlining partial compilation settings since
 the feature is now on by default

BREAKING CHANGE: This commit changes the API for automatic fallback to inline
settings regarding partial compilation in preparation for it to be turned on
by default.

Now, in the compile spec, instead of a `torch_fallback` field with its
associated struct, there are four new fields:

```c++
bool require_full_compilation = true;
uint64_t min_block_size = 3;
std::vector<std::string> torch_executed_ops = {};
std::vector<std::string> torch_executed_modules = {};
```

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 cpp/include/trtorch/trtorch.h      | 58 +++++++++++-------------------
 cpp/src/compile_spec.cpp           | 32 +++++++----------
 tests/core/BUILD                   | 31 ++++++++++++++++
 tests/cpp/test_module_fallback.cpp |  8 ++---
 4 files changed, 68 insertions(+), 61 deletions(-)

diff --git a/cpp/include/trtorch/trtorch.h b/cpp/include/trtorch/trtorch.h
index 08f1f79fe3..0c298a3321 100644
--- a/cpp/include/trtorch/trtorch.h
+++ b/cpp/include/trtorch/trtorch.h
@@ -516,38 +516,6 @@ struct TRTORCH_API CompileSpec {
     bool explicit_set_dtype;
   };
 
-  /**
-   * @brief A struct to hold fallback info
-   */
-  struct TRTORCH_API TorchFallback {
-    /// enable the automatic fallback feature
-    bool enabled = false;
-
-    /// minimum consecutive operation number that needs to be satisfied to convert to TensorRT
-    uint64_t min_block_size = 1;
-
-    /// A list of names of operations that will explicitly run in PyTorch
-    std::vector<std::string> forced_fallback_ops;
-
-    /// A list of names of modules that will explicitly run in PyTorch
-    std::vector<std::string> forced_fallback_modules;
-
-    /**
-     * @brief Construct a default Torch Fallback object, fallback will be off
-     */
-    TorchFallback() = default;
-
-    /**
-     * @brief Construct from a bool
-     */
-    TorchFallback(bool enabled) : enabled(enabled) {}
-
-    /**
-     * @brief Constructor for setting min_block_size
-     */
-    TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
-  };
-
   /**
    * @brief Construct a new Extra Info object
    * Convienence constructor to set fixed input size from vectors describing
@@ -643,11 +611,6 @@ struct TRTORCH_API CompileSpec {
    */
   Device device;
 
-  /**
-   * @brief Settings related to partial compilation
-   */
-  TorchFallback torch_fallback;
-
   /**
    * Sets the restrictions for the engine (CUDA Safety)
    */
@@ -676,6 +639,27 @@ struct TRTORCH_API CompileSpec {
    * Calibration dataloaders for each input for post training quantizatiom
    */
   nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
+
+  /**
+   * Require the full module be compiled to TensorRT instead of potentially running unsupported operations in PyTorch
+   */
+  bool require_full_compilation = false;
+
+  /**
+   * Minimum number of contiguous supported operators to compile a subgraph to TensorRT
+   */
+  uint64_t min_block_size = 3;
+
+  /**
+   * List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+   */
+  std::vector<std::string> torch_executed_ops;
+
+
+  /**
+   * List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+   */
+  std::vector<std::string> torch_executed_modules;
 };
 
 /**
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index e0ba2c719d..ff28bd2fe9 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -323,21 +323,6 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
     internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
 
-  // /* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype
-  // for inputs they will follow PyTorch convetions */ for (size_t i = 0; i < external.inputs.size(); i++) {
-  //   if (!external.inputs[i].get_explicit_set_dtype()) {
-  //     auto& precisions = internal.convert_info.engine_settings.enabled_precisions;
-  //     auto& internal_ins = internal.convert_info.inputs;
-  //     if (precisions.find(nvinfer1::DataType::kINT8) != precisions.end()) {
-  //       internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
-  //     } else if (precisions.find(nvinfer1::DataType::kHALF) != precisions.end()) {
-  //       internal_ins[i].dtype = nvinfer1::DataType::kHALF;
-  //     } else {
-  //       internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
-  //     }
-  //   }
-  // }
-
   internal.convert_info.engine_settings.sparse_weights = external.sparse_weights;
   internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
   internal.convert_info.engine_settings.refit = external.refit;
@@ -346,10 +331,19 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.strict_types = external.strict_types;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
-  internal.partition_info.enabled = external.torch_fallback.enabled;
-  internal.partition_info.min_block_size = external.torch_fallback.min_block_size;
-  internal.partition_info.forced_fallback_operators = external.torch_fallback.forced_fallback_ops;
-  internal.lower_info.forced_fallback_modules = external.torch_fallback.forced_fallback_modules;
+
+  TRTORCH_CHECK(!(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
+      "require_full_compilation is enabled however the list of ops to run in torch is not empty (Found "
+          << external.torch_executed_ops.size() << " ops)");
+
+  TRTORCH_CHECK(!(external.require_full_compilation && (external.torch_executed_modules.size() > 0)),
+      "require_full_compilation is enabled however the list of modules to run in torch is not empty (Found "
+          << external.torch_executed_modules.size() << " modules)");
+
+  internal.partition_info.enabled = external.require_full_compilation;
+  internal.partition_info.min_block_size = external.min_block_size;
+  internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops);
+  internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules);
 
   switch (external.device.device_type) {
     case CompileSpec::Device::DeviceType::kDLA:
diff --git a/tests/core/BUILD b/tests/core/BUILD
index ab0d46f7d1..fc5f788a1b 100644
--- a/tests/core/BUILD
+++ b/tests/core/BUILD
@@ -1,6 +1,37 @@
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    }
+)
+
+filegroup(
+    name = "jit_models",
+    srcs = ["//tests/modules:mobilenet_v2_scripted.jit.pt"]
+)
+
+cc_test(
+    name = "test_detecting_input_type",
+    srcs = ["test_detecting_input_type.cpp"],
+    deps = [
+        "//tests/util",
+        "//core",
+        "//core/lowering",
+        "//core/util:prelude",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+    data = [
+        ":jit_models"
+    ]
+)
+
 test_suite(
     name = "core_tests",
     tests = [
+        ":test_detecting_input_type",
         "//tests/core/conversion:conversion_tests",
         "//tests/core/lowering:lowering_tests",
         "//tests/core/partitioning:partitioning_tests"
diff --git a/tests/cpp/test_module_fallback.cpp b/tests/cpp/test_module_fallback.cpp
index f6f101b7db..530ecdbac3 100644
--- a/tests/cpp/test_module_fallback.cpp
+++ b/tests/cpp/test_module_fallback.cpp
@@ -23,8 +23,7 @@ TEST(CppAPITest, ResNetModuleFallbacksCorrectly) {
   }
 
   trtorch::CompileSpec cfg(input_shapes);
-  cfg.torch_fallback.enabled = true;
-  cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock");
+  cfg.torch_executed_modules.push_back("torchvision.models.resnet.BasicBlock");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = trtorch::CompileGraph(mod, cfg);
@@ -51,9 +50,8 @@ TEST(CppAPITest, MobileNetModuleFallbacksCorrectlyWithOneEngine) {
   }
 
   trtorch::CompileSpec cfg(input_shapes);
-  cfg.torch_fallback.enabled = true;
-  cfg.torch_fallback.min_block_size = 5;
-  cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
+  cfg.min_block_size = 5;
+  cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
 
   auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
   auto trt_mod = trtorch::CompileGraph(mod, cfg);
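
Reviewer note (not part of the patch): a minimal sketch of how caller code migrates from the old `torch_fallback` struct to the inlined settings, mirroring the updated tests above. The module path and input shape are placeholder assumptions, not values from this patch.

```c++
#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Load a TorchScript module to compile (path is a placeholder)
  auto mod = torch::jit::load("mobilenet_v2_scripted.jit.pt");

  // Before: cfg.torch_fallback.enabled = true; cfg.torch_fallback.min_block_size = 5; ...
  // Now partial compilation is on by default and its knobs live directly on the spec.
  trtorch::CompileSpec cfg({{1, 3, 224, 224}});
  cfg.min_block_size = 5;
  cfg.torch_executed_modules.push_back("torchvision.models.mobilenetv2.ConvBNActivation");
  // Setting require_full_compilation = true while either torch_executed_* list is
  // non-empty trips the new TRTORCH_CHECKs in compile_spec.cpp.

  auto trt_mod = trtorch::CompileGraph(mod, cfg);
  return 0;
}
```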