feat: Prototype Module-Acceleration in Dynamo #1921

Closed · wants to merge 3 commits
2 changes: 2 additions & 0 deletions py/setup.py
@@ -427,6 +427,8 @@ def run(self):
ext_modules=ext_modules,
install_requires=[
"torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0",
"pyyaml",
"packaging",
],
setup_requires=[],
cmdclass={
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -8,6 +8,9 @@
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
pre_aot_module_replacement,
)
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
get_submod_inputs,
@@ -46,6 +49,13 @@ def aot_torch_tensorrt_aten_backend(
settings=settings,
)

logger.debug("Pre-module replacement graph:\n" + str(gm.graph))

# Enable Pre-AOT Lowering for Module-Level Replacement
gm = pre_aot_module_replacement(gm)

logger.debug("Post-module replacement graph:\n" + str(gm.graph))

# Invoke AOTAutograd to translate operators to aten
return aot_module_simplified(
gm,
@@ -71,6 +81,8 @@ def _pretraced_backend(
Compiled FX GraphModule
"""
try:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

trt_compiled = _compile_module(
gm,
sample_inputs,
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -1,7 +1,9 @@
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
from ._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
get_submod_inputs,
from ._pre_aot_lowering import (
MODULE_SUBSTITUTION_REGISTRY,
module_substitution,
)
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
from .module_substitutions import *
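As an aside (not part of the diff): the wildcard import above is what triggers the registrations as an import side effect, since each file under module_substitutions applies the module_substitution decorator at import time. A minimal sketch of how one might confirm this, using only names introduced in this PR:

    # Sketch: after importing the lowering package, the registry should contain
    # the modules substituted in this PR (here, torch.nn.MaxPool1d).
    import torch
    from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY

    assert torch.nn.MaxPool1d in MODULE_SUBSTITUTION_REGISTRY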
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -1,9 +1,10 @@
import logging
from typing import Dict, List, Optional, Sequence
from typing import Dict, List, Optional, Sequence, Set

import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@

logger = logging.getLogger(__name__)

DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
_get_qualified_name(module.new_operator)
for module in MODULE_SUBSTITUTION_REGISTRY.values()
)
Comment on lines +18 to +21
Collaborator Author: Updated the registry to use strings formatted in a streamlined, uniform way via _get_qualified_name, as is already done in partitioning, addressing the issues raised here: #1921 (comment)
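For illustration (not part of the diff), a rough sketch of how this set of qualified names could be consulted when deciding whether a lone node deserves its own partition; the helper below is hypothetical, and the exact string returned by _get_qualified_name depends on the operator object:

    # Hypothetical helper: check whether a call_function node targets one of the
    # registered substitution operators.
    from torch.fx.node import _get_qualified_name

    def is_substitution_op(node, allowed=DEFAULT_SINGLE_NODE_PARTITIONS):
        return (
            node.op == "call_function"
            and _get_qualified_name(node.target) in allowed
        )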



class TRTPartitioner(CapabilityBasedPartitioner):
"""Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
operator_support: OperatorSupport,
*,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[
Sequence[str]
] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size=MIN_BLOCK_SIZE,
) -> None:
super().__init__(
121 changes: 121 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py
@@ -0,0 +1,121 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
import torch
import logging


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ModuleReplacement:
"""Class to store key functionality for module replacement"""

# torch.ops.___ name for replacement function for module
new_operator: torch._ops.OpOverload

# Function taking a containing graph, a submodule, and a 'call_module' node and returning
# a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
subgraph_insertion_fn: Callable[
[torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
]


# Dictionary mapping module to ModuleReplacement instance
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()


def module_substitution(
Collaborator: This is sick 😄

module_to_replace: Type[torch.nn.Module],
new_operator: torch._ops.OpOverload,
enabled: bool = True,
) -> Callable[[Any], Any]:
"""Decorator to register subgraph insertion functions

Args:
module_to_replace: nn.Module to replace
new_operator: Custom torch operator to replace with
enabled: Whether the substitution is enabled or disabled
Returns:
A decorator which registers the provided subgraph insertion function for the specified module
"""

def register_substitution(subgraph_insertion_fn):
"""Function for use if substitution is enabled"""
module_replacement = ModuleReplacement(
new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
)
MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
return subgraph_insertion_fn

def disable_substitution(subgraph_insertion_fn):
"""Function for use if substitution is disabled"""
return subgraph_insertion_fn

return register_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
"""Perform module-level graph replacement prior to AOT tracing

Args:
gm: FX GraphModule to perform module replacement on
Returns:
torch.fx.GraphModule

"""
# Ensure all parameters are in inference mode
for param in gm.parameters():
param.requires_grad = False

# Iterate over graph nodes, extracting module calls, to check for interceptions
for n in gm.graph.nodes:
if n.op == "call_module":
# Extract submodule from graph
submodule = gm.get_submodule(n.target)

# If submodule is a member of the substitution registry, replace it
if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:

try:
replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
op, insertion_fn = (
replacement.new_operator,
replacement.subgraph_insertion_fn,
)
logger.debug(
f"Replacing module of type {type(submodule)} with {op}"
)

# Insert new node prior to older node
with gm.graph.inserting_before(n):
new_node = insertion_fn(gm, submodule, n)

# If submodule is not a native torch.nn module, it must be manually excluded
# from Dynamo tracing
if not type(submodule).__module__.startswith("torch.nn"):
torch._dynamo.allowed_functions._allowed_function_ids.add(
id(type(submodule))
)

# Replace all original node uses and clean up graph
n.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()

# A module replacement can fail in the event that the specific instance of the submodule cannot
# be replaced
except Exception:
logger.debug(
f"Encountered error while replacing {type(submodule)}",
exc_info=True,
)
continue

# Perform cleanup and recompilation before returning module
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
return gm
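A minimal usage sketch for the function above (not part of the diff), assuming the MaxPool1d substitution registered later in this PR is in place:

    import torch
    from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
        pre_aot_module_replacement,
    )

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.maxpool = torch.nn.MaxPool1d(2)

        def forward(self, x):
            return self.maxpool(x)

    # After replacement, the call_module node for self.maxpool should become a
    # call_function node targeting torch.ops.tensorrt.maxpool1d.
    gm = torch.fx.symbolic_trace(Model())
    gm = pre_aot_module_replacement(gm)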
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py
@@ -0,0 +1 @@
from .maxpool1d import *
126 changes: 126 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py
@@ -0,0 +1,126 @@
from typing import Dict, Tuple
Collaborator: It would be cool to have a Sphinx tutorial on how to do this from an external user perspective. Could be as easy as removing maxpool1d from the registry, then walking through all the parts.

Collaborator Author: Agreed. I will add this, in addition to documentation on how to ensure all the relevant code is registered (__init__.py imports).

Collaborator Author (@gs-olive, Jun 1, 2023): To avoid clashing with or making dependencies on #1966 and #1967, I'll add more commenting to this file and then port it over to a formatted .py file once those PRs are merged.

import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and output
# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op.
# Then, create a placeholder function with no operations, but having the same schema and naming as that
# used in the decorator.
Comment on lines +13 to +24
Collaborator Author (@gs-olive, Jun 1, 2023): Added detailed instructions for creating a module substitution, to later be turned into a formatted .py file.

@custom_op(
"(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
# Defines operator schema, name, namespace, and function header
...


# 2. The Generic Implementation
#
# Define the default implementation of the operator in torch syntax. This is used for autograd
# and other tracing functionality. Generally, the torch.nn.functional analog of the operator to replace
# is desirable. If the operator to replace is a custom module you've written, then add its Torch
# implementation here. Note that the function header of the generic function can have specific arguments,
# as in the above placeholder.
@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
*args,
**kwargs,
):
# Defines an implementation for AOT Autograd to use for shape analysis/propagation
return torch.nn.functional.max_pool1d(
*args,
**kwargs,
)


# 3. The Module Substitution Function
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
# It should refactor the args and kwargs as is needed by the accelerated implementation.
#
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write:
#
# weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
# bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
# ...
# kwargs={"weight": weights,
# "bias": bias,
# ...
#
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
# Defines insertion function for new node
new_node = gm.graph.call_function(
torch.ops.tensorrt.maxpool1d,
args=node.args,
kwargs={
"kernel_size": submodule.kernel_size,
"stride": submodule.stride,
"padding": submodule.padding,
"dilation": submodule.dilation,
"ceil_mode": submodule.ceil_mode,
},
)

return new_node


# 4. The Accelerated Implementation
#
# Define an accelerated implementation of the operator, and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
Collaborator: For module substitution "in library" do we want to put the converter here? or do we want to put the converter in the registry with the rest?

Collaborator: For external users they'd probably put it here.

Collaborator Author: I thought it could be cleaner to have the converter implementation here, so all of the code relating to that module and its replacement is centralized. The requirement, however, is that for every new module replacement file, the user will have to add from .my_module_replacement import * to module_substitutions/__init__.py to ensure the registrations occur.

Collaborator Author: aten_ops --> tensorrt

network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
# Defines converter replacing the default operator for this function
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
"stride": args[2],
"padding": args[3],
"dilation": args[4],
"ceil_mode": False if len(args) < 6 else args[5],
}

return acc_ops_converters.acc_ops_max_pool1d(
network, target, None, kwargs_new, name
)


# 5. Add Imports
#
# Add your accelerated module file to the __init__.py in this directory, to ensure
# all registrations are run. For instance, if the new module file is called new_mod.py,
# one should add `from .new_mod import *` to the __init__.py
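For concreteness (not part of the diff), the resulting module_substitutions/__init__.py would then look roughly like this, where new_mod is the hypothetical additional substitution file mentioned above:

    from .maxpool1d import *
    from .new_mod import *  # hypothetical new module substitution file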
55 changes: 55 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py
@@ -0,0 +1,55 @@
import torch
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo import compile


class TestMaxPool1D(TestCase):
def test_pre_aot_lowering_maxpool1d(self):
class MaxPool1D(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.maxpool = torch.nn.MaxPool1d(2)

def forward(self, x):
return self.maxpool(x)

# Operations expected to be included in the traced graph after decompositions
expected_ops = {torch.ops.tensorrt.maxpool1d.default}

inputs = [
torch.rand(
9,
16,
2,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(MaxPool1D())
_, expected_ops_unseen = lower_graph_testing(
fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
self.assertAlmostEqual(
max_diff, 0, msg="Maxpool1d TRT outputs don't match the original model."
)


if __name__ == "__main__":
run_tests()
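Since the file uses run_tests from torch.testing._internal.common_utils, it can presumably be run directly as a script on a CUDA-capable machine with this branch installed, e.g. python test_pre_aot_lowering.py.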