From e1d2b4c5ed84dd142d36f741e1442fc78712752b Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 27 Sep 2024 00:36:00 -0700
Subject: [PATCH 1/2] add req_full_compilation_arg

add arg require_full_compilation
---
 py/torch_tensorrt/dynamo/_compiler.py        | 2 ++
 py/torch_tensorrt/dynamo/backend/backends.py | 4 ++++
 2 files changed, 6 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 97aa2ec443..f473ee11c4 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -376,6 +376,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 verbose=settings.debug,
                 min_block_size=settings.min_block_size,
                 torch_executed_ops=settings.torch_executed_ops,
+                require_full_compilation=settings.require_full_compilation,
             )
     except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
         logger.error(
@@ -394,6 +395,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
+            require_full_compilation=settings.require_full_compilation,
         )
 
     dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 605d963a50..3c06c3d635 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -107,6 +107,10 @@ def _pretraced_backend(
             torchtrt_inputs = prepare_inputs(
                 torch_inputs, disable_memory_format_check=True
             )
+            if settings.require_full_compilation:
+                logger.warning(
+                    "This argument is not applicable for torch.compile with backend='torch_tensorrt"
+                )
             trt_compiled = compile_module(
                 gm,
                 torchtrt_inputs,

From cf2b11be58f9dbcf64932e321fd12913c3db909c Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 27 Sep 2024 00:36:00 -0700
Subject: [PATCH 2/2] add req_full_compilation_arg

add arg require_full_compilation
---
 py/torch_tensorrt/dynamo/backend/backends.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 3c06c3d635..90e5159398 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -109,7 +109,7 @@ def _pretraced_backend(
                 )
             if settings.require_full_compilation:
                 logger.warning(
-                    "This argument is not applicable for torch.compile with backend='torch_tensorrt"
+                    "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
                 )
             trt_compiled = compile_module(
                 gm,
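
For context, a minimal usage sketch of the setting this patch threads through. The model class and input shapes below are hypothetical, and the exact option plumbing is only a sketch of how the Dynamo frontend and the torch.compile backend are expected to consume require_full_compilation:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # hypothetical module
    inputs = [torch.randn(1, 3, 224, 224).cuda()]

    # Dynamo/AOT path: require_full_compilation is forwarded to the partitioner
    # (fast_partition / global_partition), so unsupported ops are reported
    # instead of silently falling back to PyTorch execution.
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        require_full_compilation=True,
    )

    # torch.compile path: per this patch the setting is accepted but only logs
    # a warning, since graph breaks make a full-compilation guarantee moot here.
    compiled = torch.compile(
        model,
        backend="torch_tensorrt",
        options={"require_full_compilation": True},
    )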