feat: Prototype Module-Acceleration in Dynamo #1921

Status: Closed. Wants to merge 3 commits; changes shown from 2 commits.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -258,7 +258,7 @@ commands:
name: Set up python environment
command: |
pip3 install --upgrade pip
-pip3 install wheel setuptools
+pip3 install wheel setuptools pyyaml
pip3 install nvidia-pyindex
pip3 install tabulate
pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >>
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -8,6 +8,9 @@
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
pre_aot_module_replacement,
)
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
get_submod_inputs,
@@ -46,6 +49,13 @@ def aot_torch_tensorrt_aten_backend(
settings=settings,
)

logger.debug("Pre-module replacement graph:\n" + str(gm.graph))

# Enable Pre-AOT Lowering for Module-Level Replacement
gm = pre_aot_module_replacement(gm)

logger.debug("Post-module replacement graph:\n" + str(gm.graph))

# Invoke AOTAutograd to translate operators to aten
return aot_module_simplified(
gm,
@@ -71,6 +81,8 @@ def _pretraced_backend(
Compiled FX GraphModule
"""
try:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

trt_compiled = _compile_module(
gm,
sample_inputs,
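For orientation, a minimal end-to-end usage sketch (illustrative, not part of this diff) that mirrors the new test added later in this PR: compiling a traced module containing nn.MaxPool1d so the module-replacement step above is exercised. It assumes a CUDA device and a working TensorRT installation.

```python
import torch
from torch_tensorrt.dynamo import compile

# Toy model whose MaxPool1d submodule should be swapped for torch.ops.tensorrt.maxpool1d
model = torch.nn.Sequential(torch.nn.MaxPool1d(kernel_size=2))
inputs = [torch.rand(9, 16, 2).cuda()]

fx_graph = torch.fx.symbolic_trace(model)

# min_block_size=1 lets the single replaced op form its own TRT partition (mirrors the new test below)
optimized = compile(
    fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
print(optimized(*inputs).shape)
```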
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    MODULE_SUBSTITUTION_REGISTRY,
+    module_substitution,
 )
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .module_substitutions import *
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -1,9 +1,10 @@
import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Set

import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@

logger = logging.getLogger(__name__)

DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
"torch.ops." + str(module.new_operator)

Collaborator: Do you know if there's a better type than string for this registry? Is there an op type?

Collaborator (author): I will look into this more; there is the torch._ops.OpOverload type, which could be a substitute here.

Collaborator (author): After looking into this further, str is likely the best choice for this registry for now. The issue is that node targets can be a mix of functions and torch._ops.OpOverload objects. For example, torch.ops.aten.add.Tensor is an overload object representing a Tensor addition op, whereas the operator for get is an ordinary Python function. The one mechanism that connects all of these types is the _get_qualified_name function, which handles each of them and returns a string. I have updated the implementation here to use that function; see the sketch after this definition.

for module in MODULE_SUBSTITUTION_REGISTRY.values()
)
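To complement the registry-key discussion above, a short sketch (illustrative only, not part of this diff) of why a plain string works as the common key: torch.fx's _get_qualified_name maps both torch._ops.OpOverload objects and ordinary Python functions used as node targets to qualified-name strings.

```python
import operator

import torch
from torch.fx.node import _get_qualified_name

# An aten overload object commonly seen as a call_function target
print(_get_qualified_name(torch.ops.aten.add.Tensor))  # e.g. "torch.ops.aten.add.Tensor"

# An ordinary Python function used as a call_function target
print(_get_qualified_name(operator.getitem))  # e.g. "_operator.getitem"
```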


class TRTPartitioner(CapabilityBasedPartitioner):
"""Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
operator_support: OperatorSupport,
*,
non_compute_ops: Optional[Sequence[str]] = None,
-allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+allowed_single_node_partition_ops: Optional[
+    Sequence[str]
+] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size=MIN_BLOCK_SIZE,
) -> None:
super().__init__(
119 changes: 119 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py
@@ -0,0 +1,119 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
import torch
import logging


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ModuleReplacement:
"""Class to store key functionality for module replacement"""

# torch.ops.___ replacement operator for the module
new_operator: torch._ops.OpOverload

# Function taking a containing graph, a submodule, and a 'call_module' node and returning
# a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
# Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
subgraph_insertion_fn: Callable[
[torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
]


# Dictionary mapping module types to their ModuleReplacement instances
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()


def module_substitution(

Collaborator: This is sick 😄

module_to_replace: Type[torch.nn.Module],
new_operator: torch._ops.OpOverload,
enabled: bool = True,
) -> Callable[[Any], Any]:
"""Decorator to register subgraph insertion functions

Args:
module_to_replace: nn.Module type to replace
new_operator: Custom torch operator to replace with
enabled: Whether the substitution is enabled or disabled
Returns:
A decorator which registers the provided subgraph insertion function
"""

def register_substitution(subgraph_insertion_fn):
"""Function for use if substitution is enabled"""
module_replacement = ModuleReplacement(
new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
)
MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
return subgraph_insertion_fn

def disable_substitution(subgraph_insertion_fn):
"""Function for use if substitution is disabled"""
return subgraph_insertion_fn

return register_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
"""Perform module-level graph replacement prior to AOT tracing

Args:
gm: FX GraphModule to perform module replacement on
Returns:
torch.fx.GraphModule

"""
# Ensure all parameters are in inference mode
for param in gm.parameters():
param.requires_grad = False

# Iterate over graph nodes, extracting module calls, to check for interceptions
for n in gm.graph.nodes:
if n.op == "call_module":
# Extract submodule from graph
submodule = gm.get_submodule(n.target)

# If submodule is a member of the substitution registry, replace it
if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:

try:
replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
op, insertion_fn = (
replacement.new_operator,
replacement.subgraph_insertion_fn,
)
logger.debug(
f"Replacing module of type {type(submodule)} with {op}"
)

# Insert new node prior to older node
with gm.graph.inserting_before(n):
new_node = insertion_fn(gm, submodule, n)

# If submodule is not a native torch.nn module, it must be manually excluded
# from Dynamo tracing
if not type(submodule).__module__.startswith("torch.nn"):
torch._dynamo.allowed_functions._allowed_function_ids.add(
id(type(submodule))
)

# Replace all original node uses and clean up graph
n.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.recompile()

# A module replacement can fail in the event that the specific instance of the submodule cannot
# be replaced
except Exception:
logger.debug(
f"Encountered error while replacing {type(submodule)}",
exc_info=True,
)
continue

# Perform cleanup and recompilation before returning module
gm.graph.eliminate_dead_code()
gm.recompile()
return gm
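A small standalone sketch (illustrative, not part of this diff) of the function above: trace a module containing nn.MaxPool1d, run pre_aot_module_replacement, and inspect the graph for the replacement call_function node. It assumes the maxpool1d substitution added in this PR has been registered, which happens when the lowering package (and thus module_substitutions/__init__.py) is imported.

```python
import torch
# Importing through the lowering package also runs module_substitutions/__init__.py,
# which registers the maxpool1d substitution and its custom operator
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_module_replacement,
)


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.maxpool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.maxpool(x)


gm = torch.fx.symbolic_trace(ToyModel())
gm = pre_aot_module_replacement(gm)

# The call_module node for MaxPool1d should now be a call_function node targeting
# the custom torch.ops.tensorrt.maxpool1d operator
print([str(n.target) for n in gm.graph.nodes if n.op == "call_function"])
```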
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py
@@ -0,0 +1 @@
from .maxpool1d import *
75 changes: 75 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py
@@ -0,0 +1,75 @@
from typing import Dict, Tuple

Collaborator: It would be cool to have a Sphinx tutorial on how to do this from an external user's perspective. It could be as simple as removing maxpool1d from the registry and then walking through all the parts.

Collaborator (author): Agreed. I will add this, in addition to documentation on how to ensure all the relevant code is registered (the __init__.py imports); see the sketch after this file's diff.

Collaborator (author), Jun 1, 2023: To avoid clashing with or creating dependencies on #1966 and #1967, I'll add more commenting to this file and then port it over to a formatted .py file once those PRs are merged.

import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


@custom_op(
"(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
# Defines operator schema, name, namespace, and function header
...


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
*args,
**kwargs,
):
# Defines a generic implementation for AOT Autograd to use for shape analysis/propagation
return torch.nn.functional.max_pool1d(
*args,
**kwargs,
)


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(

Collaborator: For module substitution "in library", do we want to put the converter here, or do we want to put the converter in the registry with the rest?

Collaborator: For external users, they'd probably put it here.

Collaborator (author): I thought it could be cleaner to have the converter implementation here, so that all of the code relating to the module and its replacement is centralized. The requirement, however, is that for every new module replacement file, the user has to add "from .my_module_replacement import *" to module_substitutions/__init__.py to ensure the registrations occur.

Collaborator (author): Naming: aten_ops --> tensorrt.

network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
# Defines converter replacing the default operator for this function
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
"stride": args[2],
"padding": args[3],
"dilation": args[4],
"ceil_mode": False if len(args) < 6 else args[5],
}

return acc_ops_converters.acc_ops_max_pool1d(
network, target, None, kwargs_new, name
)


@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
# Defines insertion function for new node
new_node = gm.graph.call_function(
torch.ops.tensorrt.maxpool1d,
args=node.args,
kwargs={
"kernel_size": submodule.kernel_size,
"stride": submodule.stride,
"padding": submodule.padding,
"dilation": submodule.dilation,
"ceil_mode": submodule.ceil_mode,
},
)

return new_node
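Tying together the review discussion above, a hedged sketch of the external-user flow: registering a substitution for a user-defined module. The ClampScale module, the tensorrt::clamp_scale operator, and the file name my_module_replacement.py are hypothetical and only mirror the structure of maxpool1d.py (a real user would likely pick their own operator namespace); a @tensorrt_converter implementation for the new operator is still required, and the new file must be imported from module_substitutions/__init__.py for the registrations to run.

```python
# Hypothetical user-side file: module_substitutions/my_module_replacement.py
import torch
from torch._custom_op import custom_op

from torch_tensorrt.dynamo.backend.lowering import module_substitution


class ClampScale(torch.nn.Module):
    """Hypothetical user module: clamp to [-bound, bound], then scale."""

    def __init__(self, bound: float, scale: float) -> None:
        super().__init__()
        self.bound = bound
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp(x, -self.bound, self.bound) * self.scale


# Operator schema, name, and namespace, mirroring the maxpool1d example above
@custom_op("(Tensor x, float bound, float scale) -> Tensor", ns="tensorrt")
def clamp_scale(x, bound, scale):
    ...


@clamp_scale.impl("cpu")
@clamp_scale.impl("cuda")
def clamp_scale_generic(x, bound, scale):
    # Reference implementation used by AOT Autograd for shape analysis/propagation
    return torch.clamp(x, -bound, bound) * scale


@module_substitution(ClampScale, torch.ops.tensorrt.clamp_scale)
def clamp_scale_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Build the replacement call_function node from the module instance's attributes
    return gm.graph.call_function(
        torch.ops.tensorrt.clamp_scale,
        args=node.args,
        kwargs={"bound": submodule.bound, "scale": submodule.scale},
    )
```

A corresponding @tensorrt_converter(torch.ops.tensorrt.clamp_scale.default) implementation would then map the operator to TensorRT layers, in the same way aten_ops_maxpool1d does above. Finally, as noted in the review thread, the registrations only take effect once module_substitutions/__init__.py contains from .my_module_replacement import *.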
55 changes: 55 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py
@@ -0,0 +1,55 @@
import torch
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo import compile


class TestMaxPool1D(TestCase):
def test_pre_aot_lowering_maxpool1d(self):
class MaxPool1D(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.maxpool = torch.nn.MaxPool1d(2)

def forward(self, x):
return self.maxpool(x)

# Operations expected to be included in the traced graph after decompositions
expected_ops = {torch.ops.tensorrt.maxpool1d.default}

inputs = [
torch.rand(
9,
16,
2,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(MaxPool1D())
_, expected_ops_unseen = lower_graph_testing(
fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
self.assertAlmostEqual(
max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model."
)


if __name__ == "__main__":
run_tests()
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/utils.py
@@ -8,6 +8,9 @@
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
)
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
pre_aot_module_replacement,
)

from torch._dynamo.backends.common import fake_tensor_unsupported

@@ -31,6 +34,8 @@ def fx_dynamo_testing_backend(
torch_executed_ops=torch_executed_ops,
)

gm = pre_aot_module_replacement(gm)

Collaborator: Why is there a separate testing backend? Would we need to continue to make changes to this in step with the actual one?

Collaborator (author): The main reason for the separate testing backend is the store_intermediate_graphs argument, which is needed to track intermediate partitioned modules at different points in the compilation and to ensure decompositions, fusions, etc. are being applied. As changes are made to the main backend, yes, those changes would need to be reflected here and in the compile_module_testing function below; however, these are higher-level functions which do not change often.


# Invoke AOTAutograd to translate operators to aten
return aot_module_simplified(
gm,