feat: Prototype Module-Acceleration in Dynamo #1921
@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    MODULE_SUBSTITUTION_REGISTRY,
+    module_substitution,
 )
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .module_substitutions import *
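For orientation, a minimal sketch (not part of this diff) of what the lowering package now re-exports and how downstream code might query it. It assumes the package is imported normally, which also runs the module_substitutions registrations added later in this PR:

import torch

# The lowering package re-exports the substitution registry and decorator (see diff above);
# importing it also runs .module_substitutions, which registers the built-in substitutions.
from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY

# Expected to be True once the MaxPool1d substitution defined later in this PR has registered
print(torch.nn.MaxPool1d in MODULE_SUBSTITUTION_REGISTRY)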
@@ -0,0 +1,121 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
import torch
import logging

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ModuleReplacement:
    """Class to store key functionality for module replacement"""

    # torch.ops.___ name for replacement function for module
    new_operator: torch._ops.OpOverload

    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
    # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
    subgraph_insertion_fn: Callable[
        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
    ]


# Dictionary mapping module to ModuleReplacement instance
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()


def module_substitution(
    module_to_replace: Type[torch.nn.Module],
    new_operator: torch._ops.OpOverload,
    enabled: bool = True,
) -> Callable[[Any], Any]:
    """Decorator to register subgraph insertion functions

    Args:
        module_to_replace: nn.Module to replace
        new_operator: Custom torch operator to replace with
        enabled: Whether the substitution is enabled or disabled
    Returns:
        A decorator which registers the wrapped subgraph insertion function
    """

    def register_substitution(subgraph_insertion_fn):
        """Function for use if substitution is enabled"""
        module_replacement = ModuleReplacement(
            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
        )
        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
        return subgraph_insertion_fn

    def disable_substitution(subgraph_insertion_fn):
        """Function for use if substitution is disabled"""
        return subgraph_insertion_fn

    return register_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
    """Perform module-level graph replacement prior to AOT tracing

    Args:
        gm: FX GraphModule to perform module replacement on
    Returns:
        torch.fx.GraphModule
    """
    # Ensure all parameters are in inference mode
    for param in gm.parameters():
        param.requires_grad = False

    # Iterate over graph nodes, extracting module calls, to check for interceptions
    for n in gm.graph.nodes:
        if n.op == "call_module":
            # Extract submodule from graph
            submodule = gm.get_submodule(n.target)

            # If submodule is a member of the substitution registry, replace it
            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:
                try:
                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
                    op, insertion_fn = (
                        replacement.new_operator,
                        replacement.subgraph_insertion_fn,
                    )
                    logger.debug(
                        f"Replacing module of type {type(submodule)} with {op}"
                    )

                    # Insert new node prior to older node
                    with gm.graph.inserting_before(n):
                        new_node = insertion_fn(gm, submodule, n)

                    # If submodule is not a native torch.nn module, it must be manually excluded
                    # from Dynamo tracing
                    if not type(submodule).__module__.startswith("torch.nn"):
                        torch._dynamo.allowed_functions._allowed_function_ids.add(
                            id(type(submodule))
                        )

                    # Replace all original node uses and clean up graph
                    n.replace_all_uses_with(new_node)
                    gm.graph.eliminate_dead_code()
                    gm.graph.lint()
                    gm.recompile()

                # A module replacement can fail in the event that the specific instance
                # of the submodule cannot be replaced
                except Exception:
                    logger.debug(
                        f"Encountered error while replacing {type(submodule)}",
                        exc_info=True,
                    )
                    continue

    # Perform cleanup and recompilation before returning module
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm

Review comment (on module_substitution): This is sick 😄
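The backend integration is not shown in this diff, but conceptually `pre_aot_module_replacement` runs on the FX graph that Dynamo hands to the backend, before AOT tracing decomposes `call_module` nodes. A minimal sketch of exercising the pass directly, assuming the package is installed and the built-in MaxPool1d substitution has been registered via the package imports:

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_module_replacement,
)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        return self.maxpool(x)


# Symbolic tracing produces a GraphModule containing a call_module node for MaxPool1d
gm = torch.fx.symbolic_trace(ToyModel())

# The pass swaps the call_module node for a call_function node targeting the
# registered custom operator (torch.ops.tensorrt.maxpool1d in this PR)
gm = pre_aot_module_replacement(gm)
print(gm.graph)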
@@ -0,0 +1 @@
from .maxpool1d import *
@@ -0,0 +1,126 @@

Review thread (on this tutorial file):
Comment: It would be cool to have a sphinx tutorial on how to do this from an external user perspective. Could be as easy as removing maxpool1d from the registry, then walking through all the parts.
Reply: Agreed. I will add this, in addition to documentation on how to ensure all the relevant code is registered (…)

from typing import Dict, Tuple

import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and output
# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
# Then, create a placeholder function with no operations, but having the same schema and naming as that
# used in the decorator
@custom_op(
    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
    ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
    # Defines operator schema, name, namespace, and function header
    ...

Review comment (on lines +13 to +24): Added detailed instructions for creating a module substitution, to later be turned into a formatted (…)


# 2. The Generic Implementation
#
# Define the default implementation of the operator in torch syntax. This is used for autograd
# and other tracing functionality. Generally, the torch.nn.functional analog of the operator to replace
# is desirable. If the operator to replace is a custom module you've written, then add its Torch
# implementation here. Note that the function header to the generic function can have specific arguments
# as in the above placeholder
@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )


# 3. The Module Substitution Function
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
# It should refactor the args and kwargs as is needed by the accelerated implementation.
#
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write:
#
#     weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
#     ...
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


# 4. The Accelerated Implementation
#
# Define an accelerated implementation of the operator, and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines converter replacing the default operator for this function
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": False if len(args) < 6 else args[5],
    }

    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, kwargs_new, name
    )

Review thread (on the converter registration in step 4):
Comment: For module substitution "in library" do we want to put the converter here? Or do we want to put the converter in the registry with the rest?
Comment: For external users they'd probably put it here.
Reply: I thought it could be cleaner to have the converter implementation here, so all of the code relating to that module and its replacement is centralized. The requirement, however, is that for every new module replacement file, the user will have to add (…)


# 5. Add Imports
#
# Add your accelerated module file to the __init__.py in this directory, to ensure
# all registrations are run. For instance, if the new module file is called new_mod.py,
# one should add `from .new_mod import *` to the __init__.py
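To make the weight-handling note in step 3 concrete, here is a purely illustrative sketch of an insertion function for a weighted module. The operator torch.ops.tensorrt.fused_linear is hypothetical (not part of this PR), and a real version would also need the placeholder, generic implementation, converter, and import registration from steps 1, 2, 4, and 5 above:

import torch

from torch_tensorrt.dynamo.backend.lowering import module_substitution


# Hypothetical example only: assumes a custom op "torch.ops.tensorrt.fused_linear"
# has already been declared and given a generic implementation (steps 1 and 2)
@module_substitution(torch.nn.Linear, torch.ops.tensorrt.fused_linear)
def linear_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Insert get_attr nodes so the accelerated op can consume the submodule's tensors
    # (assumes the Linear layer was constructed with bias=True)
    weight = gm.graph.get_attr(node.target + ".weight", torch.Tensor)
    bias = gm.graph.get_attr(node.target + ".bias", torch.Tensor)

    # Repackage args/kwargs into the layout the accelerated implementation expects
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.fused_linear,
        args=node.args,
        kwargs={"weight": weight, "bias": bias},
    )
    return new_node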
@@ -0,0 +1,55 @@
import torch
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo import compile


class TestMaxPool1D(TestCase):
    def test_pre_aot_lowering_maxpool1d(self):
        class MaxPool1D(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.maxpool = torch.nn.MaxPool1d(2)

            def forward(self, x):
                return self.maxpool(x)

        # Operations expected to be included in the traced graph after decompositions
        expected_ops = {torch.ops.tensorrt.maxpool1d.default}

        inputs = [
            torch.rand(
                9,
                16,
                2,
            ).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(MaxPool1D())
        _, expected_ops_unseen = lower_graph_testing(
            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
        )

        self.assertEquals(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = compile(
            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
        self.assertAlmostEqual(
            max_diff,
            0,
            msg="Maxpool1d TRT outputs don't match with the original model.",
        )


if __name__ == "__main__":
    run_tests()
Review comment: Updated the registry to use strings formatted in a streamlined/uniform way via get_qualified_name, as is already done in partitioning, addressing the issues raised here: #1921 (comment)
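A rough sketch of how such a string-keyed set of operator names might be built, assuming torch.fx's private _get_qualified_name helper (the same helper FX partitioning uses) and the DEFAULT_SINGLE_NODE_PARTITIONS name exported earlier in this diff; this is an illustration, not necessarily the PR's final code:

from typing import Set

from torch.fx.node import _get_qualified_name

from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    MODULE_SUBSTITUTION_REGISTRY,
)

# Derive uniformly formatted operator-name strings from the substitution registry,
# so partitioning can treat each substituted op as its own single-node partition
DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = {
    _get_qualified_name(replacement.new_operator)
    for replacement in MODULE_SUBSTITUTION_REGISTRY.values()
}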