diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index 6247373b1f..6bf1a6d194 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -16,6 +16,7 @@ DEBUG, MAX_WORKSPACE_SIZE, MIN_BLOCK_SIZE, + PASS_THROUGH_BUILD_FAILURES, ) @@ -52,7 +53,8 @@ def compile( logger.warn( "The Dynamo backend is an experimental feature, for which only the " + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}" + + "{enabled_precisions, debug, workspace_size, min_block_size, " + + "torch_executed_ops, pass_through_build_failures}" ) if not isinstance(inputs, collections.abc.Sequence): @@ -106,6 +108,7 @@ def create_backend( workspace_size: int = MAX_WORKSPACE_SIZE, min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Sequence[str] = set(), + pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, **kwargs, ): """Create torch.compile backend given specified arguments @@ -124,6 +127,7 @@ def create_backend( workspace_size=workspace_size, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, + pass_through_build_failures=pass_through_build_failures, ) return partial( diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py index b1ee62dfa3..fe7b5f6b4f 100644 --- a/py/torch_tensorrt/dynamo/backend/_defaults.py +++ b/py/torch_tensorrt/dynamo/backend/_defaults.py @@ -5,3 +5,4 @@ DEBUG = False MAX_WORKSPACE_SIZE = 20 << 30 MIN_BLOCK_SIZE = 5 +PASS_THROUGH_BUILD_FAILURES = False diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py index 8c1a807343..df3212f54a 100644 --- a/py/torch_tensorrt/dynamo/backend/_settings.py +++ b/py/torch_tensorrt/dynamo/backend/_settings.py @@ -7,6 +7,7 @@ DEBUG, MAX_WORKSPACE_SIZE, MIN_BLOCK_SIZE, + PASS_THROUGH_BUILD_FAILURES, ) @@ -17,3 +18,4 @@ class 
CompilationSettings: workspace_size: int = MAX_WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE torch_executed_ops: Sequence[str] = field(default_factory=set) + pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 4c2c5fdcc4..8f6408492a 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -1,6 +1,6 @@ +import logging from typing import Sequence import torch -import traceback from functools import partial import torch._dynamo as td @@ -19,6 +19,9 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +logger = logging.getLogger(__name__) + + @td.register_backend(name="torch_tensorrt") @fake_tensor_unsupported def torch_tensorrt_backend( @@ -75,12 +78,22 @@ def _pretraced_backend( ) return trt_compiled except: - traceback.print_exc() - print( + logger.error( "FX2TRT conversion failed on the subgraph. See trace above. " - + "Returning GraphModule forward instead." + + "Returning GraphModule forward instead if pass_through_build_failures=False.", + exc_info=True, ) - return gm.forward + + if not settings.pass_through_build_failures: + return gm.forward + else: + raise AssertionError( + "Halting compilation on build failure since " + + "pass_through_build_failures was specified as True. " + + "To return the default Torch implementation and avoid " + + "halting compilation on engine build failures, " + + "specify pass_through_build_failures=False." 
+ ) def _compile_module( diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index d59b710faf..48f6443e32 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -124,7 +124,7 @@ def lower_graph_testing( torch_executed_ops: Sequence[str] = set(), testing_partitioning: bool = False, ): - """Helper function to assist with graph lowering for testing of Dynamo torch_compile + """Helper function to assist with graph lowering for testing of Dynamo compile Args: fx_graph: Graph to lower diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common_utils/__init__.py new file mode 100644 index 0000000000..865d0d8d3a --- /dev/null +++ b/py/torch_tensorrt/dynamo/common_utils/__init__.py @@ -0,0 +1 @@ +from .test_utils import * diff --git a/py/torch_tensorrt/dynamo/test/utils.py b/py/torch_tensorrt/dynamo/common_utils/test_utils.py similarity index 100% rename from py/torch_tensorrt/dynamo/test/utils.py rename to py/torch_tensorrt/dynamo/common_utils/test_utils.py diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index b86817df56..e6af03ed46 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -7,7 +7,10 @@ from transformers import BertModel -from utils import COSINE_THRESHOLD, cosine_similarity +from torch_tensorrt.dynamo.common_utils.test_utils import ( + COSINE_THRESHOLD, + cosine_similarity, +) @pytest.mark.unit @@ -24,13 +27,14 @@ def test_resnet18(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) assert ( cos_sim > COSINE_THRESHOLD, - f"Resnet50 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env @@ -54,6 +58,7 @@ def test_mobilenet_v2(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -84,6 +89,7 @@ def test_efficientnet_b0(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -123,6 +129,7 @@ def test_bert_base_uncased(ir): "enabled_precisions": {torch.float}, "truncate_long_and_double": True, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -157,13 +164,14 @@ def test_resnet18_half(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.half}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) cos_sim = cosine_similarity(model(input), trt_mod(input)) assert ( cos_sim > COSINE_THRESHOLD, - f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Clean up model env