diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d213cca638..6a95151e2f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -375,6 +375,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, ) except torch.fx.passes.splitter_base.FxNetSplitterInternalError: logger.error( @@ -393,6 +394,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, ) dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 605d963a50..90e5159398 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -107,6 +107,10 @@ def _pretraced_backend( torchtrt_inputs = prepare_inputs( torch_inputs, disable_memory_format_check=True ) + if settings.require_full_compilation: + logger.warning( + "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" + ) trt_compiled = compile_module( gm, torchtrt_inputs,