feat: Prototype Module-Acceleration in Dynamo #1921
Conversation
- Add support for excluding entire Torch modules from tracing in Dynamo using Torch custom operators
- Develop new dataclass to store required replacement functions and operators in a streamlined way (a sketch follows below)
- Add new registry to store mapping between replacement operators and their corresponding dataclass
- Add documentation for easy additions of new module-level exclusion operators
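For context, a minimal sketch of how such a dataclass and registry could fit together. Only `ModuleReplacement`, `MODULE_SUBSTITUTION_REGISTRY`, and the `new_operator` field appear in the diffs below; the remaining field and its name are assumptions, not the PR's exact definition:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict

import torch


@dataclass(frozen=True)
class ModuleReplacement:
    # The custom operator standing in for the excluded module (field name from the PR diffs)
    new_operator: torch._ops.OpOverload
    # Assumed field: the function that splices the replacement op into the FX graph
    subgraph_insertion_fn: Callable[..., Any]


# Maps an excluded torch.nn.Module type to its replacement details (annotation from the PR diffs)
MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()
```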
@@ -14,6 +15,11 @@

logger = logging.getLogger(__name__)

DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
    "torch.ops." + str(module.new_operator)
    for module in MODULE_SUBSTITUTION_REGISTRY.values()
)
Do you know if there's a better type than string for this registry? Like, is there an op type?
I will look into this more - there is the `torch._ops.OpOverload` type, which could be a substitute here.
After looking more into this, `str` is likely the best choice for this registry for now. The issue is that node targets can be a mix of functions and `torch._ops.OpOverload` objects. For example, `torch.ops.aten.add.Tensor` is an overload object representing a Tensor addition op, whereas the operator for `get` is an actual Python function. The unifying framework that can connect all of these types is the `_get_qualified_name` function, which handles each of them and returns a string. I have updated the implementation here to use that function.
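For illustration (not code from the PR), the two kinds of node targets and how `_get_qualified_name` unifies them; the exact output strings may vary across PyTorch versions:

```python
import operator

import torch
from torch.fx.node import _get_qualified_name

# An OpOverload target resolves to its torch.ops path
print(_get_qualified_name(torch.ops.aten.add.Tensor))  # e.g. "torch.ops.aten.add.Tensor"

# A plain Python function target resolves to a module-qualified name
print(_get_qualified_name(operator.getitem))  # e.g. "_operator.getitem"
```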
MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()


def module_substitution(
This is sick 😄
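Since the decorator's signature is truncated in the diff above, here is a hedged sketch of what `module_substitution` might look like, reusing the `ModuleReplacement` and `MODULE_SUBSTITUTION_REGISTRY` names from the earlier sketch; parameter names are guesses:

```python
from typing import Callable

import torch


def module_substitution(
    module_to_replace: torch.nn.Module,
    new_operator: torch._ops.OpOverload,
) -> Callable[[Callable], Callable]:
    """Register a graph-rewrite function for the given module type."""

    def register_substitution(subgraph_insertion_fn: Callable) -> Callable:
        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = ModuleReplacement(
            new_operator=new_operator,
            subgraph_insertion_fn=subgraph_insertion_fn,
        )
        return subgraph_insertion_fn

    return register_substitution
```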
@@ -0,0 +1,75 @@

from typing import Dict, Tuple
It would be cool to have a Sphinx tutorial on how to do this from an external user perspective. It could be as easy as removing maxpool1d from the registry and then walking through all the parts.
Agreed. I will add this, in addition to documentation on how to ensure all the relevant code is registered (`__init__.py` imports).
py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
For module substitution "in library", do we want to put the converter here? Or do we want to put the converter in the registry with the rest?
For external users they'd probably put it here.
I thought it could be cleaner to have the converter implementation here, so all of the code relating to that module and its replacement is centralized. The requirement, however, is that for every new module replacement file, the user will have to add `from .my_module_replacement import *` to `module_substitutions/__init__.py` to ensure the registrations occur.
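Under that scheme, `module_substitutions/__init__.py` would look roughly like the following (file names other than `maxpool1d.py` are hypothetical):

```python
# module_substitutions/__init__.py
# Each import runs the registration decorators as a side effect,
# populating the substitution and converter registries.
from .maxpool1d import *
# from .my_module_replacement import *  # one line per new substitution file
```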
@@ -31,6 +34,8 @@ def fx_dynamo_testing_backend(
        torch_executed_ops=torch_executed_ops,
    )

    gm = pre_aot_module_replacement(gm)
Why is there a separate testing backend? Would we need to continue to make changes to this in step with the actual one?
The main reason for the separate testing backend is the argument `store_intermediate_graphs`. This is needed to track intermediate partitioned modules at different points in the compilation, to ensure decompositions, fusions, etc. are being utilized. As changes are made to the main backend, yes, those changes would need to be reflected here and in the `compile_module_testing` function below; however, these are higher-level functions which do not change often.
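A rough sketch of that mechanism, with heavily assumed signatures (only `fx_dynamo_testing_backend`, `store_intermediate_graphs`, and `pre_aot_module_replacement` are named in the PR):

```python
from typing import Any, List, Sequence

import torch

# Import path assumed from the PR's file layout
from torch_tensorrt.dynamo.backend.lowering import pre_aot_module_replacement


def fx_dynamo_testing_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    *,
    store_intermediate_graphs: List[torch.fx.GraphModule],
    **settings: Any,
) -> torch.fx.GraphModule:
    # Apply module substitutions before AOT tracing, mirroring the main backend
    gm = pre_aot_module_replacement(gm)

    # Snapshot the partially lowered graph so tests can assert that
    # decompositions, fusions, etc. actually occurred
    store_intermediate_graphs.append(gm)

    # ...continue with the same lowering path as the production backend...
    return gm
```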
- Fix typing issues, add dependencies to `setup.py`, add qualified name checking for module registry
- Add detailed tutorial descriptions to sample module substitution with step-by-step instructions for creating a new module substitution
DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
    _get_qualified_name(module.new_operator)
    for module in MODULE_SUBSTITUTION_REGISTRY.values()
)
Updated registry to use strings formatted in a streamlined/uniform way via `_get_qualified_name`, as is already done in partitioning, addressing the issues raised here: #1921 (comment)
# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and
# output types. The namespace, such as tensorrt, will cause the op to be registered as
# torch.ops.tensorrt.your_op. Then, create a placeholder function with no operations, but
# having the same schema and naming as that used in the decorator.
Added detailed instructions for creating a module substitution, to later be turned into a formatted `.py` file.
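As a hedged sketch of step 1 using the stock torch.library API (the PR may use a different registration decorator; the schema is taken from the maxpool1d example):

```python
import torch

# Register a "tensorrt" namespace so the op appears as torch.ops.tensorrt.maxpool1d
tensorrt_lib = torch.library.Library("tensorrt", "DEF")

# Schema in torch JIT syntax, specifying input and output types
tensorrt_lib.define(
    "maxpool1d(Tensor x, int[1] kernel_size, int[1] stride, "
    "int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor"
)


# Placeholder function with no real operations, matching the schema's name and
# signature; the actual computation is supplied later by the TensorRT converter
@torch.library.impl(tensorrt_lib, "maxpool1d", "CompositeExplicitAutograd")
def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
    ...
```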
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
aten_ops --> tensorrt
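A sketch of the converter under the reviewer's suggested naming; the import path of `tensorrt_converter` and the body are assumptions, and the argument order follows the step-1 schema since torch.compile folds kwargs into args:

```python
import torch

# Import path assumed; the decorator is shown in the diff above
from torch_tensorrt.fx.converter_registry import tensorrt_converter


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def tensorrt_maxpool1d(network, target, args, kwargs, name):
    # args follows the schema order: (x, kernel_size, stride, padding, dilation, ceil_mode)
    input_val, kernel_size, stride, padding, dilation, ceil_mode = args

    # Build and return the corresponding TensorRT pooling layer (details elided)
    ...
```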
Closed in favor of #1979
Description
Fixes #1894