diff --git a/py/setup.py b/py/setup.py index b870560ae5..eb382559f8 100644 --- a/py/setup.py +++ b/py/setup.py @@ -427,6 +427,8 @@ def run(self): ext_modules=ext_modules, install_requires=[ "torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0", + "pyyaml", + "packaging", ], setup_requires=[], cmdclass={ diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 8f6408492a..ea0d398a44 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -8,6 +8,9 @@ from torch_tensorrt.dynamo.backend.lowering._decompositions import ( get_decompositions, ) +from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( + pre_aot_module_replacement, +) from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, get_submod_inputs, @@ -46,6 +49,13 @@ def aot_torch_tensorrt_aten_backend( settings=settings, ) + logger.debug("Pre-module replacement graph:\n" + str(gm.graph)) + + # Enable Pre-AOT Lowering for Module-Level Replacement + gm = pre_aot_module_replacement(gm) + + logger.debug("Post-module replacement graph:\n" + str(gm.graph)) + # Invoke AOTAutograd to translate operators to aten return aot_module_simplified( gm, @@ -71,6 +81,8 @@ def _pretraced_backend( Compiled FX GraphModule """ try: + logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) + trt_compiled = _compile_module( gm, sample_inputs, diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py index 01b20cef6d..1a0cbab2df 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/__init__.py @@ -1,7 +1,9 @@ -from torch_tensorrt.dynamo.backend.lowering._decompositions import ( +from ._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.backend.lowering._partition import ( - partition, - get_submod_inputs, +from ._pre_aot_lowering import ( + MODULE_SUBSTITUTION_REGISTRY, + module_substitution, ) +from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS +from .module_substitutions import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 5cd83d768c..496c91a089 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -1,9 +1,10 @@ import logging -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Set import torch from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE +from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name @@ -14,6 +15,11 @@ logger = logging.getLogger(__name__) +DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( + _get_qualified_name(module.new_operator) + for module in MODULE_SUBSTITUTION_REGISTRY.values() +) + class TRTPartitioner(CapabilityBasedPartitioner): """Partitioner to split an FX graph into subgraphs based on operator support @@ -35,7 +41,9 @@ def __init__( operator_support: OperatorSupport, *, non_compute_ops: Optional[Sequence[str]] = None, - allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[ + Sequence[str] + ] = DEFAULT_SINGLE_NODE_PARTITIONS, min_block_size=MIN_BLOCK_SIZE, ) -> None: super().__init__( diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py new file mode 100644 index 0000000000..738a398a51 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -0,0 +1,121 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Type +import torch +import logging + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ModuleReplacement: + """Class to store key functionality for module replacement""" + + # torch.ops.___ name for replacement function for module + new_operator: torch._ops.OpOverload + + # Function taking a containing graph, a submodule, and a 'call_module' node and returning + # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected + # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph + subgraph_insertion_fn: Callable[ + [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node + ] + + +# Dictionary mapping module to ModuleReplacement instance +MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict() + + +def module_substitution( + module_to_replace: Type[torch.nn.Module], + new_operator: torch._ops.OpOverload, + enabled: bool = True, +) -> Callable[[Any], Any]: + """Decorator to register subgraph insertion functions + + Args: + module_to_replace: nn.Module to replace + new_operator: Custom torch operator to replace with + enabled: Whether the substitution is enabled or disabled + Returns: + torch.fx.GraphModule + """ + + def register_substitution(subgraph_insertion_fn): + """Function for use if substitution is enabled""" + module_replacement = ModuleReplacement( + new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn + ) + MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement + return subgraph_insertion_fn + + def disable_substitution(subgraph_insertion_fn): + """Function for use if substitution is disabled""" + return subgraph_insertion_fn + + return register_substitution if enabled else disable_substitution + + +def pre_aot_module_replacement(gm: torch.fx.GraphModule): + """Perform module-level graph replacement prior to AOT tracing + + Args: + gm: FX GraphModule to perform module replacement on + Returns: + torch.fx.GraphModule + + """ + # Ensure all parameters are in inference mode + for param in gm.parameters(): + param.requires_grad = False + + # Iterate over graph nodes, extracting module calls, to check for interceptions + for n in gm.graph.nodes: + if n.op == "call_module": + # Extract submodule from graph + submodule = gm.get_submodule(n.target) + + # If submodule is a member of the substitution registry, replace it + if type(submodule) in MODULE_SUBSTITUTION_REGISTRY: + + try: + replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)] + op, insertion_fn = ( + replacement.new_operator, + replacement.subgraph_insertion_fn, + ) + logger.debug( + f"Replacing module of type {type(submodule)} with {op}" + ) + + # Insert new node prior to older node + with gm.graph.inserting_before(n): + new_node = insertion_fn(gm, submodule, n) + + # If submodule is not a native torch.nn module, it must be manually excluded + # from Dynamo tracing + if not type(submodule).__module__.startswith("torch.nn"): + torch._dynamo.allowed_functions._allowed_function_ids.add( + id(type(submodule)) + ) + + # Replace all original node uses and clean up graph + n.replace_all_uses_with(new_node) + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + # A module replacement can fail in the event that the specific instance of the submodule cannot + # be replaced + except Exception: + logger.debug( + f"Encountered error while replacing {type(submodule)}", + exc_info=True, + ) + continue + + # Perform cleanup and recompilation before returning module + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + return gm diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py new file mode 100644 index 0000000000..4b8ba88e34 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py @@ -0,0 +1 @@ +from .maxpool1d import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py new file mode 100644 index 0000000000..a45ad146e7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py @@ -0,0 +1,126 @@ +from typing import Dict, Tuple +import torch +from torch._custom_op import custom_op +from torch.fx.node import Argument, Target + +from torch_tensorrt.fx.converter_registry import tensorrt_converter +from torch_tensorrt.fx.converters import acc_ops_converters +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from torch_tensorrt.dynamo.backend.lowering import module_substitution + + +# This file serves as an example and a tutorial for excluding custom modules from +# torch.compile tracing. Each required step is labeled with a number indicating the +# preferable implementation order. + + +# 1. The Placeholder +# +# Specify the schema and namespace of the operator, as well as a placeholder function +# representing the schema. The schema should be in torch JIT syntax, indicating input and output +# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op +# Then, create a placeholder function with no operations, but having the same schema and naming as that +# used in the decorator +@custom_op( + "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor", + ns="tensorrt", +) +def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False): + # Defines operator schema, name, namespace, and function header + ... + + +# 2. The Generic Implementation +# +# Define the default implementation of the operator in torch syntax. This is used for autograd +# and other tracing functionality. Generally, the torch.nn.functional analog of the operator to replace +# is desirable. If the operator to replace is a custom module you've written, then add its Torch +# implementation here. Note that the function header to the generic function can have specific arguments +# as in the above placeholder +@maxpool1d.impl("cpu") +@maxpool1d.impl("cuda") +def maxpool1d_generic( + *args, + **kwargs, +): + # Defines an implementation for AOT Autograd to use for shape analysis/propagation + return torch.nn.functional.max_pool1d( + *args, + **kwargs, + ) + + +# 3. The Module Substitution Function +# +# Define a function which can intercept a node of the kind to be replaced, extract +# the relevant data from that node/submodule, and then re-package the information +# for use by an accelerated implementation (to be implemented in step 4). This function +# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d). +# It should refactor the args and kwargs as is needed by the accelerated implementation. +# +# If the submodule has weights or other Tensor fields which the accelerated implementation +# needs, the function should insert the necessary nodes to access those weights. For example, +# if the weight Tensor of a submodule is needed, one could write: +# +# weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor) +# bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor) +# ... +# kwargs={"weight": weights, +# "bias": bias, +# ... +# +@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d) +def maxpool1d_insertion_fn( + gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node +) -> torch.fx.Node: + # Defines insertion function for new node + new_node = gm.graph.call_function( + torch.ops.tensorrt.maxpool1d, + args=node.args, + kwargs={ + "kernel_size": submodule.kernel_size, + "stride": submodule.stride, + "padding": submodule.padding, + "dilation": submodule.dilation, + "ceil_mode": submodule.ceil_mode, + }, + ) + + return new_node + + +# 4. The Accelerated Implementation +# +# Define an accelerated implementation of the operator, and register it as necessary. +# This accelerated implementation should consume the args/kwargs specified in step 3. +# One should expect that torch.compile will compress all kwargs into the args field in +# the order specified in the schema written in step 1. +@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default) +def aten_ops_maxpool1d( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> TRTTensor: + # Defines converter replacing the default operator for this function + kwargs_new = { + "input": args[0], + "kernel_size": args[1], + "stride": args[2], + "padding": args[3], + "dilation": args[4], + "ceil_mode": False if len(args) < 6 else args[5], + } + + return acc_ops_converters.acc_ops_max_pool1d( + network, target, None, kwargs_new, name + ) + + +# 5. Add Imports +# +# Add your accelerated module file to the __init__.py in this directory, to ensure +# all registrations are run. For instance, if the new module file is called new_mod.py, +# one should add `from .new_mod import *` to the __init__.py diff --git a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py new file mode 100644 index 0000000000..2fa65bfabc --- /dev/null +++ b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py @@ -0,0 +1,55 @@ +import torch +from utils import lower_graph_testing +from torch.testing._internal.common_utils import run_tests, TestCase +from torch_tensorrt.dynamo import compile + + +class TestMaxPool1D(TestCase): + def test_pre_aot_lowering_maxpool1d(self): + class MaxPool1D(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.maxpool = torch.nn.MaxPool1d(2) + + def forward(self, x): + return self.maxpool(x) + + # Operations expected to be included in the traced graph after decompositions + expected_ops = {torch.ops.tensorrt.maxpool1d.default} + + inputs = [ + torch.rand( + 9, + 16, + 2, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(MaxPool1D()) + _, expected_ops_unseen = lower_graph_testing( + fx_graph, inputs, expected_ops=expected_ops, min_block_size=1 + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = compile( + fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results)) + self.assertAlmostEqual( + max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model." + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index 48f6443e32..e7dc435ac4 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -8,6 +8,9 @@ from torch_tensorrt.dynamo.backend.lowering._partition import ( partition, ) +from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( + pre_aot_module_replacement, +) from torch._dynamo.backends.common import fake_tensor_unsupported @@ -31,6 +34,8 @@ def fx_dynamo_testing_backend( torch_executed_ops=torch_executed_ops, ) + gm = pre_aot_module_replacement(gm) + # Invoke AOTAutograd to translate operators to aten return aot_module_simplified( gm,